Commit 404ecbdc authored by zbian's avatar zbian
Browse files

Migrated project

parent 2ebaefc5
#!/usr/bin/env python
# -*- encoding: utf-8 -*-
import torch.nn as nn
from colossalai.context import ParallelMode
from colossalai.core import global_context as gpc
class ParallelLayer(nn.Module):
    """Base class for parallel layers.

    Caches the local rank and world size of the data, tensor and pipeline
    parallel groups on the instance, falling back to rank 0 / size 1 for any
    group that has not been initialized.
    """

    def __init__(self):
        super().__init__()
        self.data_parallel_rank, self.data_parallel_size = \
            self._rank_and_size(ParallelMode.DATA)
        self.tensor_parallel_rank, self.tensor_parallel_size = \
            self._rank_and_size(ParallelMode.TENSOR)
        self.pipeline_parallel_rank, self.pipeline_parallel_size = \
            self._rank_and_size(ParallelMode.PIPELINE)

    @staticmethod
    def _rank_and_size(mode):
        # Default to a single-process view when this parallel group is absent.
        if gpc.is_initialized(mode):
            return gpc.get_local_rank(mode), gpc.get_world_size(mode)
        return 0, 1
from .layers import Linear1D_Col, Linear1D_Row
# Public API of the 1D tensor-parallel layer package.
__all__ = ['Linear1D_Col', 'Linear1D_Row']
#!/usr/bin/env python
# -*- encoding: utf-8 -*-
from .._common_utils import divide
def vocab_range_from_per_partition_vocab_size(per_partition_vocab_size, rank):
    """Return the half-open global index range ``[first, last)`` of the
    vocabulary slice owned by *rank*, given a uniform per-partition size."""
    start = rank * per_partition_vocab_size
    return start, start + per_partition_vocab_size
def vocab_range_from_global_vocab_size(global_vocab_size, rank, world_size):
    """Evenly split *global_vocab_size* across *world_size* ranks and return
    the half-open index range ``[first, last)`` owned by *rank*.

    ``divide`` raises if the vocabulary is not evenly divisible.
    """
    partition_size = divide(global_vocab_size, world_size)
    return vocab_range_from_per_partition_vocab_size(partition_size, rank)
#!/usr/bin/env python
# -*- encoding: utf-8 -*-
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.nn.init as init
from torch import Tensor
from torch.nn.parameter import Parameter
from typing import Tuple
from colossalai.context.parallel_mode import ParallelMode
from colossalai.core import global_context as gpc
from colossalai.registry import LAYERS
from colossalai.utils import get_current_device
from .._common_utils import divide
from .._parallel_utilities import reduce_grad, reduce_input, gather_forward_split_backward, \
split_forward_gather_backward
from ..base_layer import ParallelLayer
@LAYERS.register_module
class Linear1D_Col(ParallelLayer):
    """Linear layer with column parallelism.

    The linear layer is defined as :math:`Y = XA + b`. A is parallelized along
    its second dimension as :math:`A = [A_1, ..., A_p]`.

    When ``bias=False``, ``skip_bias_add`` is enabled and :meth:`forward`
    returns a ``(output, bias)`` tuple instead of a single tensor, so a fused
    downstream kernel may apply the bias itself.

    :param in_features: first dimension of matrix A.
    :type in_features: int
    :param output_size: second dimension of matrix A.
    :type output_size: int
    :param bias: If true, add bias, defaults to True
    :type bias: bool, optional
    :param dtype: The dtype of parameters, defaults to None
    :type dtype: torch.dtype, optional
    :param gather_output: If true, call all-gather on output and make Y available
                          to all GPUs, otherwise, every GPU will have its output
                          which is :math:`Y_i = XA_i`, defaults to False
    :type gather_output: bool, optional
    """

    def __init__(self,
                 in_features: int,
                 output_size: int,
                 bias: bool = True,
                 dtype: torch.dtype = None,
                 gather_output: bool = False):
        super().__init__()

        # Keep input parameters
        self.input_size = in_features
        self.output_size = output_size
        self.gather_output = gather_output
        self.skip_bias_add = not bias

        # Each 1D-parallel rank owns a slice of the output dimension;
        # divide() raises if output_size is not evenly divisible.
        world_size = gpc.get_world_size(ParallelMode.PARALLEL_1D)
        self.output_size_per_partition = divide(output_size, world_size)

        # Parameters, created directly on the current device.
        # Initialize weight.
        factory_kwargs = {'device': get_current_device(), 'dtype': dtype}
        self.weight = Parameter(torch.empty(
            self.output_size_per_partition, self.input_size,
            **factory_kwargs))

        if bias:
            self.bias = Parameter(torch.empty(
                self.output_size_per_partition,
                **factory_kwargs))
            # Always initialize bias to zero.
            with torch.no_grad():
                self.bias.zero_()
        else:
            self.register_parameter('bias', None)

    def forward(self, input_: Tensor) -> Tuple[Tensor, Tensor]:
        """Apply the column-parallel linear transform.

        :return: ``(output, bias)`` when ``skip_bias_add`` is set, otherwise
                 just ``output``.
        """
        # Set up backprop all-reduce: identity in forward, gradient all-reduce
        # across the 1D-parallel group in backward.
        input_parallel = reduce_grad(input_, ParallelMode.PARALLEL_1D)
        # Matrix multiply on this rank's output slice.
        bias = self.bias if not self.skip_bias_add else None
        output_parallel = F.linear(input_parallel, self.weight, bias)

        if self.gather_output:
            # All-gather across the partitions.
            output = gather_forward_split_backward(
                output_parallel, ParallelMode.PARALLEL_1D, dim=-1)
        else:
            output = output_parallel

        if self.skip_bias_add:
            return output, self.bias
        else:
            return output
@LAYERS.register_module
class Linear1D_Row(ParallelLayer):
    """Linear layer with row parallelism.

    :param in_features: size of each input sample
    :type in_features: int
    :param out_features: size of each output sample
    :type out_features: int
    :param bias: If set to ``False``, the layer will not learn an additive bias, defaults to True
    :type bias: bool, optional
    :param dtype: The dtype of parameters, defaults to None
    :type dtype: torch.dtype, optional
    :param parallel_input: If set to ``True``, the input is assumed to be
                           already split across the 1D-parallel group;
                           otherwise it is split here, defaults to False
    :type parallel_input: bool, optional
    """

    def __init__(self,
                 in_features: int,
                 out_features: int,
                 bias: bool = True,
                 dtype: torch.dtype = None,
                 parallel_input: bool = False
                 ):
        super().__init__()

        self.in_features = in_features
        self.out_features = out_features
        self.parallel_input = parallel_input
        self.skip_bias_add = not bias

        # The weight matrix is partitioned along its last dimension, so each
        # rank in the 1D-parallel group holds a slice of the input features.
        num_partitions = gpc.get_world_size(ParallelMode.PARALLEL_1D)
        self.input_size_per_partition = divide(in_features, num_partitions)

        # Allocate parameters directly on the current device.
        param_kwargs = {'device': get_current_device(), 'dtype': dtype}
        self.weight = Parameter(torch.empty(
            self.out_features,
            self.input_size_per_partition,
            **param_kwargs))

        if not bias:
            self.register_parameter('bias', None)
        else:
            self.bias = Parameter(torch.empty(
                self.out_features,
                **param_kwargs))
            # The bias always starts at zero.
            with torch.no_grad():
                self.bias.zero_()

    def reset_parameters(self) -> None:
        """Re-initialize the weight with Xavier normal initialization."""
        init.xavier_normal_(self.weight)

    def forward(self, input_: Tensor) -> Tensor:
        """Compute the row-parallel linear transform of ``input_``."""
        if not self.parallel_input:
            # Split along the feature dimension in forward; gather the
            # gradients back in backward.
            input_ = split_forward_gather_backward(
                input_, ParallelMode.PARALLEL_1D, dim=-1)

        partial_output = F.linear(input_, self.weight)
        # Sum the partial products across the 1D-parallel group.
        output = reduce_input(partial_output, ParallelMode.PARALLEL_1D)

        if not self.skip_bias_add:
            output = output + self.bias
        return output
from ._operation import Matmul_AB_2D, Matmul_ABT_2D, Matmul_ATB_2D, Add_Bias_2D, matmul_2d
from ._transformer import TransformerMLP2D, TransformerSelfAttention2D, TransformerLayer2D
from ._vit import ViTMLP2D, ViTSelfAttention2D, ViTHead2D, ViTPatchEmbedding2D, ViTTokenFuser2D, ViTInputSplitter2D
from .layers import Linear2D, LayerNorm2D
# Public API of the 2D tensor-parallel layer package.
__all__ = [
    'Matmul_AB_2D', 'Matmul_ABT_2D', 'Matmul_ATB_2D', 'Add_Bias_2D',
    'matmul_2d',
    'TransformerMLP2D', 'TransformerSelfAttention2D', 'TransformerLayer2D',
    'ViTMLP2D', 'ViTSelfAttention2D', 'ViTHead2D', 'ViTPatchEmbedding2D',
    'ViTTokenFuser2D', 'ViTInputSplitter2D',
    'Linear2D', 'LayerNorm2D',
]
from typing import Any, Tuple
import torch
import torch.distributed as dist
from torch import Tensor
from colossalai.context.parallel_mode import ParallelMode
from colossalai.core import global_context as gpc
from colossalai.utils import get_current_device
def matmul_2d(a,
              b,
              summa_dim,
              out_shape,
              row_rank=None,
              col_rank=None,
              row_parallel_mode=ParallelMode.PARALLEL_2D_ROW,
              col_parallel_mode=ParallelMode.PARALLEL_2D_COL,
              ):
    """Matrix multiplication for 2D parallelism

    :param a: matrix :math:`A`
    :type a: torch.tensor
    :param b: matrix :math:`B`
    :type b: torch.tensor
    :param summa_dim: dimension of SUMMA for 2D parallelism
    :type summa_dim: int
    :param out_shape: shape of output tensor
    :type out_shape: tuple
    :param row_rank: the rank of row, defaults to None
    :type row_rank: int, optional
    :param col_rank: the rank of column, defaults to None
    :type col_rank: int, optional
    :param row_parallel_mode: row parallel mode, defaults to ParallelMode.PARALLEL_2D_ROW
    :type row_parallel_mode: str, optional
    :param col_parallel_mode: column parallel mode, defaults to ParallelMode.PARALLEL_2D_COL
    :type col_parallel_mode: str, optional
    :return: :math:`C = AB`
    :rtype: torch.tensor
    """
    # Infer the caller's position in the 2D process grid when not supplied.
    if row_rank is None:
        row_rank = gpc.get_local_rank(col_parallel_mode)
    if col_rank is None:
        col_rank = gpc.get_local_rank(row_parallel_mode)

    # Fall back to a single-process view for groups that are not initialized.
    if gpc.is_initialized(ParallelMode.DATA):
        data_parallel_rank = gpc.get_local_rank(ParallelMode.DATA)
    else:
        data_parallel_rank = 0
    if gpc.is_initialized(ParallelMode.PIPELINE):
        pipeline_parallel_rank = gpc.get_local_rank(ParallelMode.PIPELINE)
        pipeline_parallel_size = gpc.get_world_size(ParallelMode.PIPELINE)
    else:
        pipeline_parallel_rank = 0
        pipeline_parallel_size = 1

    # The tensor-parallel group is the full summa_dim x summa_dim grid.
    tensor_parallel_size = summa_dim ** 2

    return Matmul_AB_2D(a, b, summa_dim, out_shape, row_rank, col_rank,
                        row_parallel_mode, col_parallel_mode,
                        data_parallel_rank, pipeline_parallel_rank,
                        pipeline_parallel_size, tensor_parallel_size)
class Matmul_AB_2D(torch.autograd.Function):
    """Matrix multiplication for :math:`C = AB`

    SUMMA-style 2D tensor-parallel matmul: each rank holds one
    (row_rank, col_rank) tile of A and B on a ``summa_dim x summa_dim``
    process grid. At step ``i`` the owners of the i-th tiles broadcast them
    along the process row / column and every rank accumulates the partial
    product into its local C tile.
    """

    @staticmethod
    def forward(ctx: Any,
                A: Tensor,
                B: Tensor,
                summa_dim: int,
                out_shape: Tuple[int, ...],
                row_rank: int,
                col_rank: int,
                row_parallel_mode: ParallelMode,
                col_parallel_mode: ParallelMode,
                data_parallel_rank: int,
                pipeline_parallel_rank: int,
                pipeline_parallel_size: int,
                tensor_parallel_size: int) -> Tensor:
        # A: [b / q, s, h / q] -> [(b * s) / q, h / q]
        # B: [h / q, s / q]
        # C: [b / q, s, s / q] -> [(b * s) / q, s / q]
        assert A.shape[-1] == B.shape[-2], \
            'Invalid shapes: A={}, B={} for AB.'.format(A.shape, B.shape)

        # ctx is None when invoked manually from a sibling backward pass;
        # in that case nothing is recorded for autograd.
        if ctx:
            ctx.save_for_backward(A, B)

        # Flatten leading dimensions so the broadcasts and addmm operate on
        # plain 2D tiles; the original shapes are kept for backward.
        A_shape = A.shape
        A = A.reshape((-1, A_shape[-1]))
        B_shape = B.shape
        B = B.reshape((-1, B_shape[-1]))
        C_shape = (A.shape[0], B.shape[-1])
        C = torch.zeros(C_shape, dtype=A.dtype, device=get_current_device())

        for i in range(summa_dim):
            A_temp = A.clone()
            B_temp = B.clone()
            # Global rank of the process-row member owning the i-th A tile.
            # The data/pipeline terms appear to offset into this rank's
            # tensor-parallel group within the global rank numbering
            # (assumes that rank layout -- NOTE(review): confirm).
            src_a = i + summa_dim * row_rank + data_parallel_rank * pipeline_parallel_size * tensor_parallel_size + \
                pipeline_parallel_rank * tensor_parallel_size
            dist.broadcast(A_temp, src=src_a,
                           group=gpc.get_group(row_parallel_mode))
            # Global rank of the process-column member owning the i-th B tile.
            src_b = col_rank + summa_dim * i + data_parallel_rank * pipeline_parallel_size * tensor_parallel_size + \
                pipeline_parallel_rank * tensor_parallel_size
            dist.broadcast(B_temp, src=src_b,
                           group=gpc.get_group(col_parallel_mode))
            # Accumulate the partial product in place: C = C + A_temp @ B_temp.
            torch.addmm(C, A_temp, B_temp, out=C)

        out = C.reshape(out_shape)

        if ctx:
            # Stash everything backward needs to re-invoke the sibling ops.
            ctx.summa_dim = summa_dim
            ctx.row_rank = row_rank
            ctx.col_rank = col_rank
            ctx.row_parallel_mode = row_parallel_mode
            ctx.col_parallel_mode = col_parallel_mode
            ctx.A_shape = A_shape
            ctx.B_shape = B_shape
            ctx.data_parallel_rank = data_parallel_rank
            ctx.pipeline_parallel_rank = pipeline_parallel_rank
            ctx.pipeline_parallel_size = pipeline_parallel_size
            ctx.tensor_parallel_size = tensor_parallel_size
        return out

    @staticmethod
    def backward(ctx: Any, output_grad: Tensor) -> Tuple[Tensor, ...]:
        """Compute dA = dC B^T and dB = A^T dC with the distributed ABT/ATB
        kernels; passing ctx=None keeps these calls off the autograd tape."""
        A, B = ctx.saved_tensors
        A_grad = Matmul_ABT_2D.forward(
            None,
            output_grad, B,
            ctx.summa_dim, ctx.A_shape,
            ctx.row_rank, ctx.col_rank,
            ctx.row_parallel_mode,
            ctx.col_parallel_mode,
            ctx.data_parallel_rank,
            ctx.pipeline_parallel_rank,
            ctx.pipeline_parallel_size,
            ctx.tensor_parallel_size
        )
        B_grad = Matmul_ATB_2D.forward(
            None,
            A, output_grad,
            ctx.summa_dim, ctx.B_shape,
            ctx.row_rank, ctx.col_rank,
            ctx.row_parallel_mode,
            ctx.col_parallel_mode,
            ctx.data_parallel_rank,
            ctx.pipeline_parallel_rank,
            ctx.pipeline_parallel_size,
            ctx.tensor_parallel_size
        )
        # One None per non-tensor forward argument.
        return A_grad, B_grad, None, None, None, None, None, None, None, None, None, None
class Matmul_ABT_2D(torch.autograd.Function):
    """Matrix multiplication for :math:`C = AB^T`

    SUMMA-style 2D tensor-parallel kernel: at each step the i-th B tile is
    broadcast along the process column, every rank forms a local partial
    product, and the partials are reduced to the rank owning that C tile.
    """

    @staticmethod
    def forward(ctx: Any,
                A: Tensor,
                B: Tensor,
                summa_dim: int,
                out_shape: Tuple[int, ...],
                row_rank: int,
                col_rank: int,
                row_parallel_mode: ParallelMode,
                col_parallel_mode: ParallelMode,
                data_parallel_rank: int,
                pipeline_parallel_rank: int,
                pipeline_parallel_size: int,
                tensor_parallel_size: int
                ) -> Tensor:
        assert A.shape[-1] == B.shape[-1], \
            'Invalid shapes: A={}, B={} for ABT.'.format(A.shape, B.shape)

        # ctx is None when invoked manually from a sibling backward pass.
        if ctx:
            ctx.save_for_backward(A, B)

        # Flatten leading dimensions so the collectives work on 2D tiles.
        A_shape = A.shape
        A = A.reshape((-1, A_shape[-1]))
        B_shape = B.shape
        B = B.reshape((-1, B_shape[-1]))
        C_shape = (A.shape[0], B.shape[0])
        C = torch.empty(C_shape, dtype=A.dtype, device=get_current_device())

        for i in range(summa_dim):
            B_temp = B.clone()
            # C_temp = torch.zeros(C_shape, dtype=C.dtype, device=get_current_device())
            # Global rank of the process-column member owning the i-th B tile.
            src_b = col_rank + summa_dim * i + data_parallel_rank * pipeline_parallel_size * tensor_parallel_size + \
                pipeline_parallel_rank * tensor_parallel_size
            dist.broadcast(B_temp, src=src_b,
                           group=gpc.get_group(col_parallel_mode))
            C_temp = torch.matmul(A, B_temp.transpose(0, 1))
            # Reduce the partials to the process-row member owning this step's
            # C tile.
            src_c = i + summa_dim * row_rank + data_parallel_rank * pipeline_parallel_size * tensor_parallel_size + \
                pipeline_parallel_rank * tensor_parallel_size
            dist.reduce(C_temp, dst=src_c,
                        group=gpc.get_group(row_parallel_mode))
            # This rank is the destination at step i == col_rank; keep the
            # fully-reduced tile as the local result.
            if i == col_rank:
                C = C_temp.clone()

        out = C.reshape(out_shape)

        if ctx:
            # Stash everything backward needs to re-invoke the sibling ops.
            ctx.summa_dim = summa_dim
            ctx.row_rank = row_rank
            ctx.col_rank = col_rank
            ctx.row_parallel_mode = row_parallel_mode
            ctx.col_parallel_mode = col_parallel_mode
            ctx.A_shape = A_shape
            ctx.B_shape = B_shape
            ctx.data_parallel_rank = data_parallel_rank
            ctx.pipeline_parallel_rank = pipeline_parallel_rank
            ctx.pipeline_parallel_size = pipeline_parallel_size
            ctx.tensor_parallel_size = tensor_parallel_size
        return out

    @staticmethod
    def backward(ctx: Any, output_grad: Tensor) -> Tuple[Tensor, ...]:
        """Compute dA = dC B and dB = dC^T A with the distributed AB/ATB
        kernels; passing ctx=None keeps these calls off the autograd tape."""
        A, B = ctx.saved_tensors
        A_grad = Matmul_AB_2D.forward(
            None,
            output_grad, B,
            ctx.summa_dim, ctx.A_shape,
            ctx.row_rank, ctx.col_rank,
            ctx.row_parallel_mode,
            ctx.col_parallel_mode,
            ctx.data_parallel_rank,
            ctx.pipeline_parallel_rank,
            ctx.pipeline_parallel_size,
            ctx.tensor_parallel_size
        )
        B_grad = Matmul_ATB_2D.forward(
            None,
            output_grad, A,
            ctx.summa_dim, ctx.B_shape,
            ctx.row_rank, ctx.col_rank,
            ctx.row_parallel_mode,
            ctx.col_parallel_mode,
            ctx.data_parallel_rank,
            ctx.pipeline_parallel_rank,
            ctx.pipeline_parallel_size,
            ctx.tensor_parallel_size
        )
        # One None per non-tensor forward argument.
        return A_grad, B_grad, None, None, None, None, None, None, None, None, None, None
class Matmul_ATB_2D(torch.autograd.Function):
    """Matrix multiplication for :math:`C = A^TB`

    SUMMA-style 2D tensor-parallel kernel: at each step the i-th A tile is
    broadcast along the process row, every rank forms a local partial
    product, and the partials are reduced to the rank owning that C tile.
    """

    @staticmethod
    def forward(ctx: Any,
                A: Tensor,
                B: Tensor,
                summa_dim: int,
                out_shape: Tuple[int, ...],
                row_rank: int,
                col_rank: int,
                row_parallel_mode: ParallelMode,
                col_parallel_mode: ParallelMode,
                data_parallel_rank: int,
                pipeline_parallel_rank: int,
                pipeline_parallel_size: int,
                tensor_parallel_size: int
                ) -> Tensor:
        assert A.shape[-2] == B.shape[-2], \
            'Invalid shapes: A={}, B={} for ATB.'.format(A.shape, B.shape)

        # ctx is None when invoked manually from a sibling backward pass.
        if ctx:
            ctx.save_for_backward(A, B)

        # Flatten leading dimensions so the collectives work on 2D tiles.
        A_shape = A.shape
        A = A.reshape((-1, A_shape[-1]))
        B_shape = B.shape
        B = B.reshape((-1, B_shape[-1]))
        C_shape = (A.shape[-1], B.shape[-1])
        C = torch.empty(C_shape, dtype=A.dtype, device=get_current_device())

        for i in range(summa_dim):
            A_temp = A.clone()
            # C_temp = torch.zeros(C_shape, dtype=C.dtype, device=get_current_device())
            # Global rank of the process-row member owning the i-th A tile.
            src_a = i + summa_dim * row_rank + data_parallel_rank * pipeline_parallel_size * tensor_parallel_size + \
                pipeline_parallel_rank * tensor_parallel_size
            dist.broadcast(A_temp, src=src_a,
                           group=gpc.get_group(row_parallel_mode))
            C_temp = torch.matmul(A_temp.transpose(0, 1), B)
            # Reduce the partials to the process-column member owning this
            # step's C tile.
            src_c = col_rank + summa_dim * i + data_parallel_rank * pipeline_parallel_size * tensor_parallel_size + \
                pipeline_parallel_rank * tensor_parallel_size
            dist.reduce(C_temp, dst=src_c,
                        group=gpc.get_group(col_parallel_mode))
            # This rank is the destination at step i == row_rank; keep the
            # fully-reduced tile as the local result.
            if i == row_rank:
                C = C_temp.clone()

        out = C.reshape(out_shape)

        if ctx:
            # Stash everything backward needs to re-invoke the sibling ops.
            ctx.summa_dim = summa_dim
            ctx.row_rank = row_rank
            ctx.col_rank = col_rank
            ctx.row_parallel_mode = row_parallel_mode
            ctx.col_parallel_mode = col_parallel_mode
            ctx.A_shape = A_shape
            ctx.B_shape = B_shape
            ctx.data_parallel_rank = data_parallel_rank
            ctx.pipeline_parallel_rank = pipeline_parallel_rank
            ctx.pipeline_parallel_size = pipeline_parallel_size
            ctx.tensor_parallel_size = tensor_parallel_size
        return out

    @staticmethod
    def backward(ctx: Any, output_grad: Tensor) -> Tuple[Tensor, ...]:
        """Compute dA = B dC^T and dB = A dC with the distributed ABT/AB
        kernels; passing ctx=None keeps these calls off the autograd tape."""
        A, B = ctx.saved_tensors
        A_grad = Matmul_ABT_2D.forward(
            None,
            B, output_grad,
            ctx.summa_dim, ctx.A_shape,
            ctx.row_rank, ctx.col_rank,
            ctx.row_parallel_mode,
            ctx.col_parallel_mode,
            ctx.data_parallel_rank,
            ctx.pipeline_parallel_rank,
            ctx.pipeline_parallel_size,
            ctx.tensor_parallel_size
        )
        B_grad = Matmul_AB_2D.forward(
            None,
            A, output_grad,
            ctx.summa_dim, ctx.B_shape,
            ctx.row_rank, ctx.col_rank,
            ctx.row_parallel_mode,
            ctx.col_parallel_mode,
            ctx.data_parallel_rank,
            ctx.pipeline_parallel_rank,
            ctx.pipeline_parallel_size,
            ctx.tensor_parallel_size
        )
        # One None per non-tensor forward argument.
        return A_grad, B_grad, None, None, None, None, None, None, None, None, None, None
class Add_Bias_2D(torch.autograd.Function):
    """Matrix add bias: :math:`C = A + b`

    The bias slice lives on the ranks whose ``row_rank`` is 0; forward
    broadcasts it down each process column before (optionally) adding it.
    When ``skip_bias_add`` is set, forward returns only the broadcast bias so
    a later fused kernel can apply it.
    """

    @staticmethod
    def forward(ctx: Any,
                input: Tensor,
                bias: Tensor,
                output_size_per_partition: int,
                row_rank: int,
                col_rank: int,
                row_parallel_mode: ParallelMode,
                col_parallel_mode: ParallelMode,
                skip_bias_add: bool,
                data_parallel_rank: int,
                pipeline_parallel_rank: int,
                pipeline_parallel_size: int,
                tensor_parallel_size: int
                ) -> Tensor:
        # Only the first process row holds real bias values; other rows
        # allocate a zero buffer to receive the broadcast.
        if row_rank == 0:
            bias_temp = bias.clone()
        else:
            bias_temp = torch.zeros(
                output_size_per_partition,
                dtype=bias.dtype,
                device=get_current_device())
        # Global rank of this column's row-0 member (the bias owner).
        src_rank = col_rank + data_parallel_rank * pipeline_parallel_size * tensor_parallel_size + \
            pipeline_parallel_rank * tensor_parallel_size
        dist.broadcast(bias_temp, src=src_rank,
                       group=gpc.get_group(col_parallel_mode))

        ctx.row_rank = row_rank
        ctx.col_rank = col_rank
        ctx.row_parallel_mode = row_parallel_mode
        ctx.col_parallel_mode = col_parallel_mode
        # NOTE: despite the name, ctx.bias stores the skip_bias_add flag,
        # not the bias tensor.
        ctx.bias = skip_bias_add
        ctx.data_parallel_rank = data_parallel_rank
        ctx.pipeline_parallel_rank = pipeline_parallel_rank
        ctx.pipeline_parallel_size = pipeline_parallel_size
        ctx.tensor_parallel_size = tensor_parallel_size

        if skip_bias_add:
            return bias_temp
        else:
            output = input + bias_temp
            return output

    @staticmethod
    def backward(ctx: Any, output_grad: Tensor) -> Tuple[Tensor, ...]:
        """Route the bias gradient back to the owning (row 0) rank of each
        process column; non-owners return zeros instead of None so the layer
        stays compatible with the zero optimizer."""
        row_rank = ctx.row_rank
        col_rank = ctx.col_rank
        row_parallel_mode = ctx.row_parallel_mode
        col_parallel_mode = ctx.col_parallel_mode
        data_parallel_rank = ctx.data_parallel_rank
        pipeline_parallel_rank = ctx.pipeline_parallel_rank
        pipeline_parallel_size = ctx.pipeline_parallel_size
        tensor_parallel_size = ctx.tensor_parallel_size

        if ctx.bias:
            # skip_bias_add path: forward returned only the bias, so the
            # incoming grad is the bias grad; reduce it to the owner rank.
            dst_rank = col_rank + data_parallel_rank * pipeline_parallel_size * tensor_parallel_size + \
                pipeline_parallel_rank * tensor_parallel_size
            dist.reduce(output_grad, dst=dst_rank,
                        group=gpc.get_group(col_parallel_mode))
            if row_rank == 0:
                return None, output_grad, None, None, None, None, None, None, None, None, None, None
            else:
                # for compatibility with zero optimizer, no grad should be None
                grad_tmp = torch.zeros_like(output_grad)
                return None, grad_tmp, None, None, None, None, None, None, None, None, None, None
        else:
            # Normal path: input grad passes through unchanged; bias grad is
            # the output grad summed over all but the last dimension.
            reduce_dim = tuple(range(output_grad.ndim - 1))
            reduce = torch.sum(output_grad, dim=reduce_dim)
            dst_rank = col_rank + data_parallel_rank * pipeline_parallel_size * tensor_parallel_size + \
                pipeline_parallel_rank * tensor_parallel_size
            dist.reduce(reduce, dst=dst_rank,
                        group=gpc.get_group(col_parallel_mode))
            if row_rank == 0:
                return output_grad, reduce, None, None, None, None, None, None, None, None, None, None
            else:
                # for compatibility with zero optimizer, no grad should be None
                reduce_tmp = torch.zeros_like(reduce)
                return output_grad, reduce_tmp, None, None, None, None, None, None, None, None, None, None
class _LayerNorm_2D(torch.autograd.Function):
    """Normalization step of a 2D-parallel layer norm.

    The caller supplies the precomputed mean ``E_x`` and ``Var_x``, where
    ``Var_x`` is actually :math:`1 / \\sqrt{Var[x] + eps}` (see comments
    below); this op only applies the normalization and differentiates it.
    """

    @staticmethod
    def forward(ctx: Any,
                input: Tensor,
                E_x: Tensor,
                Var_x: Tensor,
                hidden_size: int,
                row_parallel_mode: ParallelMode,
                col_parallel_mode: ParallelMode) -> Tensor:
        input = input - E_x
        # in here, input = x - E[x], Var_x = 1 / sqrt(Var[x] + eps)
        ctx.normalized_shape = hidden_size
        output = input * Var_x
        # Save the normalized output (not the raw input) -- backward is
        # written in terms of x_hat.
        ctx.save_for_backward(output, Var_x)
        ctx.row_parallel_mode = row_parallel_mode
        ctx.col_parallel_mode = col_parallel_mode
        return output

    @staticmethod
    def backward(ctx: Any, output_grad: Tensor) -> Tuple[Tensor, ...]:
        """Standard layer-norm input gradient; the per-row sums are
        all-reduced across the row group because the hidden dimension is
        partitioned over those ranks."""
        row_parallel_mode = ctx.row_parallel_mode
        col_parallel_mode = ctx.col_parallel_mode
        x, Var_x = ctx.saved_tensors
        # in here, Var_x = 1 / sqrt(Var[x] + eps), x = (x - E[x]) * Var_x
        output_grad_sum = torch.sum(output_grad, dim=-1, keepdim=True)
        torch.distributed.all_reduce(
            output_grad_sum, group=gpc.get_group(row_parallel_mode))
        # Divide by the FULL hidden size, not the local partition width.
        output_grad_sum /= ctx.normalized_shape

        output_grad_mul_x_sum = torch.sum(
            output_grad * x, dim=-1, keepdim=True)
        torch.distributed.all_reduce(
            output_grad_mul_x_sum, group=gpc.get_group(row_parallel_mode))
        output_grad_mul_x_sum /= ctx.normalized_shape

        # dL/dx = (dL/dy - mean(dL/dy) - x_hat * mean(dL/dy * x_hat)) / std
        input_grad = output_grad.clone()
        input_grad -= x * output_grad_mul_x_sum
        input_grad -= output_grad_sum
        input_grad *= Var_x
        return input_grad, None, None, None, None, None
# class Sum_2D(torch.autograd.Function):
#
# @staticmethod
# def forward(ctx: Any,
# inputs: Tensor,
# dim: int,
# summa_dim: int,
# row_parallel_mode: ParallelMode,
# keepdim: bool = False) -> Tensor:
# # input: [b/q, s, h/q]
# empty_cache()
# ctx.save_for_backward(inputs)
# # sum: [b/q, s]
# out = torch.sum(inputs, dim=dim, keepdim=keepdim)
# torch.distributed.all_reduce(out, group=gpc.get_group(row_parallel_mode))
# return out
#
# @staticmethod
# def backward(ctx: Any, output_grad: Tensor) -> Tuple[Tensor, ...]:
# with torch.no_grad():
# inputs = ctx.saved_tensors
# input_grad = torch.ones(inputs.shape, dtype=output_grad.dtype)
# return input_grad, None, None, None, None, None
class _ViT_Split_Input_2D(torch.autograd.Function):
    """Split the ViT input batch across the 2D column-parallel group.

    Forward keeps this rank's batch chunk; backward all-gathers the chunk
    gradients back into a full-batch gradient.
    """

    @staticmethod
    def forward(ctx: Any,
                inputs: Tensor,
                batch_size: int,
                summa_dim: int,
                col_parallel_mode: ParallelMode) -> Tensor:
        # inputs: [b, s, h/q]
        # output: [b/q, s, h/q]
        ctx.BATCH_SIZE = batch_size
        ctx.summa_dim = summa_dim
        ctx.col_parallel_mode = col_parallel_mode
        # This rank's position within the column group selects its chunk.
        row_rank = gpc.get_local_rank(col_parallel_mode)
        output = torch.chunk(inputs, summa_dim, dim=0)[row_rank]
        # Clone so the result owns its memory instead of viewing `inputs`.
        output = output.clone()
        return output

    @staticmethod
    def backward(ctx: Any, output_grad: Tensor) -> Tuple[Tensor, ...]:
        # output_grad: [b/q, s, h/q]
        # grads: [b, s, h/q]
        grads_shape = (ctx.BATCH_SIZE,) + output_grad.shape[1:]
        grads = torch.empty(grads_shape,
                            dtype=output_grad.dtype,
                            device=get_current_device())
        # Gather directly into chunk views of the preallocated buffer, one
        # chunk per rank in the column group.
        dist.all_gather(list(grads.chunk(ctx.summa_dim, dim=0)),
                        output_grad.contiguous(),
                        group=gpc.get_group(ctx.col_parallel_mode))
        return grads, None, None, None
#!/usr/bin/env python
# -*- encoding: utf-8 -*-
import math
import torch
from torch import nn as nn, Tensor
from colossalai.nn.layer._common_utils import divide, ACT2FN
from colossalai.nn.layer.parallel_2d._utils import assert_summa_initialization, get_summa_dim_from_env
from colossalai.registry import LAYERS
from .layers import Linear2D, LayerNorm2D
from ..base_layer import ParallelLayer
@LAYERS.register_module
class TransformerMLP2D(ParallelLayer):
    """2D-parallel transformer MLP block.

    Projects the ``in_features``-dimensional hidden state up to
    ``mlp_ratio * in_features``, applies a nonlinearity, projects back down,
    then applies dropout followed by a residual layer norm.

    :param in_features: the size of input tensor
    :type in_features: int
    :param mlp_ratio: hidden size of MLP divided by embedding dim, defaults to 4.0
    :type mlp_ratio: int, optional
    :param act_func: activation function, defaults to 'gelu'
    :type act_func: str, optional
    :param dropout_prob: dropout probability, defaults to 0.
    :type dropout_prob: float, optional
    :param dtype: dtype of parameters, defaults to None
    :type dtype: torch.dtype, optional
    :param skip_bias_add: If set to ``True``, it will skip bias add for linear layer, which is preserved for kernel fusion, defaults to False
    :type skip_bias_add: bool, optional
    """

    def __init__(self,
                 in_features: int,
                 mlp_ratio: int = 4.0,
                 act_func: str = 'gelu',
                 dropout_prob: float = 0.,
                 dtype=None,
                 skip_bias_add: bool = False
                 ):
        super().__init__()
        assert_summa_initialization()
        self.summa_dim = get_summa_dim_from_env()
        self.in_features = in_features
        self.skip_bias_add = skip_bias_add

        hidden_dim = int(mlp_ratio * in_features)

        # Up-projection to h * mlp_ratio.
        self.dense_1 = Linear2D(
            in_features,
            hidden_dim,
            dtype=dtype,
            skip_bias_add=self.skip_bias_add
        )

        assert act_func in ACT2FN.keys(), f'Invalid value for argument act_func, ' \
            f'activation function can only be {list(ACT2FN.keys())}'
        self.activation_func = ACT2FN[act_func]

        # Down-projection back to h.
        self.dense_2 = Linear2D(
            hidden_dim,
            in_features,
            dtype=dtype,
            skip_bias_add=self.skip_bias_add
        )
        self.dropout = nn.Dropout(dropout_prob)
        self.layernorm = LayerNorm2D(in_features, dtype=dtype)

    def forward(self, x: Tensor) -> Tensor:
        # When skip_bias_add is set, Linear2D returns an (output, bias) pair;
        # the bias is discarded here.
        intermediate = self.dense_1(x)
        if self.skip_bias_add:
            intermediate, _ = intermediate
        intermediate = self.activation_func(intermediate)

        out = self.dense_2(intermediate)
        if self.skip_bias_add:
            out, _ = out

        out = self.dropout(out)
        # Residual connection followed by 2D-parallel layer norm.
        return self.layernorm(x + out)
@LAYERS.register_module
class TransformerSelfAttention2D(ParallelLayer):
    """Self attention layer for 2D parallel Transformer

    :param hidden_size: hidden size
    :type hidden_size: int
    :param num_attention_heads: number of attention heads
    :type num_attention_heads: int
    :param attention_dropout_prob: dropout probability for attention layer
    :type attention_dropout_prob: float
    :param hidden_dropout_prob: dropout probability for hidden layer
    :type hidden_dropout_prob: float
    :param dtype: dtype of parameters, defaults to None
    :type dtype: torch.dtype, optional
    """

    def __init__(self,
                 hidden_size: int,
                 num_attention_heads: int,
                 attention_dropout_prob: float,
                 hidden_dropout_prob: float,
                 dtype=None,
                 ):
        super().__init__()

        assert_summa_initialization()
        self.summa_dim = get_summa_dim_from_env()
        self.hidden_size = hidden_size
        # Heads are divided over the summa grid dimension, so each rank
        # computes num_attention_heads / summa_dim heads; the head size is
        # derived from the GLOBAL head count.
        self.num_attention_heads = divide(num_attention_heads, self.summa_dim)
        self.attention_head_size = divide(hidden_size, num_attention_heads)
        self.all_head_size = self.num_attention_heads * self.attention_head_size

        # Fused Q/K/V projection: one matmul produces all three.
        self.query_key_value = Linear2D(
            hidden_size,
            3 * hidden_size,
            dtype=dtype,
        )
        self.attention_dropout = nn.Dropout(attention_dropout_prob)
        self.dense = Linear2D(
            hidden_size,
            hidden_size,
            dtype=dtype,
        )
        self.dropout = nn.Dropout(hidden_dropout_prob)
        self.layernorm = LayerNorm2D(
            hidden_size,
            dtype=dtype)

    def forward(self, hidden_states: Tensor, attention_mask: Tensor) -> Tensor:
        """Standard scaled-dot-product attention with a residual layer norm.

        ``attention_mask`` is ADDED to the raw scores, so it is expected to be
        additive (e.g. large negative values at masked positions) --
        NOTE(review): confirm against callers.
        """
        query_key_value = self.query_key_value(hidden_states)
        # [b, s, 3h'] -> [b, s, heads, 3 * head_size] -> [b, heads, s, 3 * head_size]
        new_qkv_shape = query_key_value.shape[:-1] + \
            (self.num_attention_heads, 3 * self.attention_head_size)
        query_key_value = query_key_value.view(new_qkv_shape)
        query_key_value = query_key_value.permute((0, 2, 1, 3))
        # Split the fused last dimension into Q, K, V.
        query_layer, key_layer, value_layer = torch.chunk(
            query_key_value, 3, dim=-1)

        attention_scores = torch.matmul(
            query_layer, key_layer.transpose(-1, -2))
        attention_scores = attention_scores / \
            math.sqrt(self.attention_head_size)
        attention_scores = attention_scores + attention_mask
        attention_probs = nn.Softmax(dim=-1)(attention_scores)
        attention_probs = self.attention_dropout(attention_probs)

        context_layer = torch.matmul(attention_probs, value_layer)
        # [b, heads, s, head_size] -> [b, s, heads * head_size]
        context_layer = context_layer.permute((0, 2, 1, 3)).contiguous()
        new_context_layer_shape = context_layer.size()[
            :-2] + (self.all_head_size,)
        context_layer = context_layer.view(*new_context_layer_shape)

        output = self.dense(context_layer)
        output = self.dropout(output)
        # Post-norm residual connection.
        attention_output = self.layernorm(hidden_states + output)
        return attention_output
@LAYERS.register_module
class TransformerLayer2D(ParallelLayer):
    """Transformer layer which contains a self-attention layer and a MLP layer

    :param hidden_size: hidden size
    :type hidden_size: int
    :param num_attention_heads: number of attention heads
    :type num_attention_heads: int
    :param act_func: activation function, defaults to 'gelu'
    :type act_func: str, optional
    :param mlp_ratio: hidden size of MLP divided by embedding dim, defaults to 4.0
    :type mlp_ratio: float, optional
    :param attention_dropout_prob: dropout probability for attention layer, defaults to 0.
    :type attention_dropout_prob: float, optional
    :param hidden_dropout_prob: dropout probability for hidden layer, defaults to 0.
    :type hidden_dropout_prob: float, optional
    :param dtype: dtype of parameters, defaults to None
    :type dtype: torch.dtype, optional
    """

    def __init__(self,
                 hidden_size: int,
                 num_attention_heads: int,
                 act_func: str = 'gelu',
                 mlp_ratio: float = 4.0,
                 attention_dropout_prob: float = 0.,
                 hidden_dropout_prob: float = 0.,
                 dtype=None,
                 ):
        super().__init__()
        # Both sub-layers apply their own residual layer norm internally.
        self.attention = TransformerSelfAttention2D(
            hidden_size=hidden_size,
            num_attention_heads=num_attention_heads,
            attention_dropout_prob=attention_dropout_prob,
            hidden_dropout_prob=hidden_dropout_prob,
            dtype=dtype,
        )
        self.mlp = TransformerMLP2D(
            in_features=hidden_size,
            dropout_prob=hidden_dropout_prob,
            act_func=act_func,
            mlp_ratio=mlp_ratio,
            dtype=dtype,
        )

    def forward(self, hidden_states: Tensor, attention_mask: Tensor) -> Tensor:
        """Run self-attention then the MLP block."""
        attended = self.attention(hidden_states, attention_mask)
        return self.mlp(attended)
import os
from colossalai.context.parallel_mode import ParallelMode
from colossalai.context.process_group_initializer.initializer_2d import SUMMA_DIM
from colossalai.core import global_context as gpc
def get_summa_dim_from_env() -> int:
    """Read the SUMMA dimension (the side length of the 2D process grid)
    from the environment variable named by ``SUMMA_DIM``.

    :return: the SUMMA dimension
    :rtype: int
    :raises EnvironmentError: if the variable is not set in the environment
    """
    try:
        summa_dim = os.environ[SUMMA_DIM]
        summa_dim = int(summa_dim)
        assert summa_dim > 0, 'SUMMA_DIM must be larger than zero'
        return summa_dim
    except KeyError as e:
        # Chain the original KeyError (the original code bound it but never
        # used it) so the traceback shows the root cause.
        raise EnvironmentError('SUMMA_DIM is not found in the current environment, '
                               'please make sure that you have used the correct process group initializer') from e
def assert_summa_initialization():
    """Assert that both 2D-parallel process groups (row and column) have been
    set up by the process group initializer."""
    col_ready = gpc.is_initialized(ParallelMode.PARALLEL_2D_COL)
    row_ready = gpc.is_initialized(ParallelMode.PARALLEL_2D_ROW)
    assert col_ready and row_ready, \
        'Both TWO_DIMENSION_COL and TWO_DIMENSION_ROW must be initialized by the process group initializer'
#!/usr/bin/env python
# -*- encoding: utf-8 -*-
import math
import torch
from torch import nn as nn, Tensor, distributed as dist
from colossalai.context import seed, ParallelMode
from colossalai.core import global_context as gpc
from colossalai.nn.layer._common_utils import divide, ACT2FN
from colossalai.nn.layer.parallel_2d._utils import assert_summa_initialization, get_summa_dim_from_env
from colossalai.nn.layer.vanilla_vision_transformer.layers import to_2tuple
from colossalai.registry import LAYERS
from colossalai.utils import checkpoint
from colossalai.utils import get_current_device
from ._operation import _ViT_Split_Input_2D
from .layers import Linear2D
from .._common_utils import set_tensor_parallel_attribute
from ..base_layer import ParallelLayer
@LAYERS.register_module
class ViTMLP2D(ParallelLayer):
    """MLP layer for 2D parallel Vision Transformer

    :param in_features: size of each input sample
    :type in_features: int
    :param mlp_ratio: hidden size of MLP divided by embedding dim
    :type mlp_ratio: int
    :param act_func: activation function, defaults to 'gelu'
    :type act_func: str, optional
    :param dropout_prob: dropout probability, defaults to 0.
    :type dropout_prob: float, optional
    :param dtype: The dtype of parameters, defaults to None
    :type dtype: torch.dtype, optional
    :param checkpoint: whether to checkpoint the layer, defaults to False
    :type checkpoint: bool, optional
    """

    def __init__(self,
                 in_features: int,
                 mlp_ratio: int,
                 act_func: str = 'gelu',
                 dropout_prob: float = 0.,
                 dtype=None,
                 checkpoint: bool = False
                 ):
        super().__init__()

        assert_summa_initialization()
        self.summa_dim = get_summa_dim_from_env()
        self.in_features = in_features
        self.mlp_ratio = mlp_ratio
        self.checkpoint = checkpoint

        # int(...) keeps the hidden size integral even for a float ratio
        # (e.g. 4.0), matching TransformerMLP2D in this package.
        hidden_features = int(self.mlp_ratio * self.in_features)

        # Project to mlp_ratio * h.
        self.dense_1 = Linear2D(
            self.in_features,
            hidden_features,
            dtype=dtype,
        )
        self.act = ACT2FN[act_func]
        # Project back to h.
        self.dense_2 = Linear2D(
            hidden_features,
            self.in_features,
            dtype=dtype,
        )
        self.dropout = nn.Dropout(dropout_prob)

    def _forward(self, hidden_states: Tensor) -> Tensor:
        # Up-project, activate, dropout, down-project, dropout.
        intermediate_output = self.dense_1(hidden_states)
        intermediate_output = self.act(intermediate_output)

        # Dropout runs inside the TENSOR-parallel seed context; presumably
        # this coordinates RNG state across the group -- see
        # colossalai.context.seed (NOTE(review): confirm semantics).
        with seed(ParallelMode.TENSOR):
            intermediate_output = self.dropout(intermediate_output)
        output = self.dense_2(intermediate_output)

        with seed(ParallelMode.TENSOR):
            output = self.dropout(output)
        return output

    def _checkpoint_forward(self, hidden_states: Tensor) -> Tensor:
        # Trade compute for memory: recompute activations in backward.
        return checkpoint(self._forward, hidden_states)

    def forward(self, hidden_states: Tensor) -> Tensor:
        if self.checkpoint:
            return self._checkpoint_forward(hidden_states)
        else:
            return self._forward(hidden_states)
@LAYERS.register_module
class ViTSelfAttention2D(ParallelLayer):
    """Self-attention layer for 2D parallel Vision Transformer
    :param hidden_size: hidden size
    :type hidden_size: int
    :param num_attention_heads: number of attention heads
    :type num_attention_heads: int
    :param attention_dropout_prob: dropout probability for attention layers
    :type attention_dropout_prob: float
    :param hidden_dropout_prob: dropout probability for hidden layers
    :type hidden_dropout_prob: float
    :param dtype: dtype of parameters, defaults to None
    :type dtype: torch.dtype, optional
    :param checkpoint: whether to checkpoint the layer, defaults to False
    :type checkpoint: bool, optional
    """
    def __init__(self,
                 hidden_size: int,
                 num_attention_heads: int,
                 attention_dropout_prob: float,
                 hidden_dropout_prob: float,
                 dtype=None,
                 checkpoint: bool = False
                 ):
        super().__init__()
        assert_summa_initialization()
        self.summa_dim = get_summa_dim_from_env()
        self.hidden_size = hidden_size
        # the heads are partitioned across the SUMMA dimension ...
        self.num_attention_heads = divide(num_attention_heads, self.summa_dim)
        # ... while the per-head size is derived from the GLOBAL head count
        self.attention_head_size = divide(hidden_size, num_attention_heads)
        self.all_head_size = self.num_attention_heads * self.attention_head_size
        self.checkpoint = checkpoint
        # fused projection producing Q, K and V in a single 2D-parallel matmul
        self.query_key_value = Linear2D(
            hidden_size,
            3 * hidden_size,
            dtype=dtype,
        )
        self.attention_dropout = nn.Dropout(attention_dropout_prob)
        # output projection after head merge
        self.dense = Linear2D(
            hidden_size,
            hidden_size,
            dtype=dtype,
        )
        self.dropout = nn.Dropout(hidden_dropout_prob)
        self.softmax = nn.Softmax(dim=-1)
    def _forward(self, hidden_states: Tensor) -> Tensor:
        """Plain (non-checkpointed) attention forward pass."""
        query_key_value = self.query_key_value(hidden_states)
        # split the fused projection into per-head Q/K/V:
        # [..., 3h] -> [..., heads, 3 * head_size] -> heads moved before seq
        new_qkv_shape = query_key_value.shape[:-1] + \
            (self.num_attention_heads, 3 * self.attention_head_size)
        query_key_value = query_key_value.view(new_qkv_shape)
        query_key_value = query_key_value.permute((0, 2, 1, 3))
        query_layer, key_layer, value_layer = torch.chunk(
            query_key_value, 3, dim=-1)
        # scaled dot-product attention
        attention_scores = torch.matmul(
            query_layer, key_layer.transpose(-1, -2))
        attention_scores = attention_scores / \
            math.sqrt(self.attention_head_size)
        attention_probs = self.softmax(attention_scores)
        # dropout runs under the tensor-parallel RNG state
        with seed(ParallelMode.TENSOR):
            attention_probs = self.attention_dropout(attention_probs)
        context_layer = torch.matmul(attention_probs, value_layer)
        # merge the head dimension back: [b, heads, s, d] -> [b, s, heads * d]
        context_layer = context_layer.transpose(1, 2)
        new_context_layer_shape = context_layer.size()[
            :-2] + (self.all_head_size,)
        context_layer = context_layer.reshape(new_context_layer_shape)
        output = self.dense(context_layer)
        with seed(ParallelMode.TENSOR):
            output = self.dropout(output)
        return output
    def _checkpoint_forward(self, hidden_states: Tensor) -> Tensor:
        """Forward pass wrapped in activation checkpointing."""
        return checkpoint(self._forward, hidden_states)
    def forward(self, hidden_states: Tensor) -> Tensor:
        if self.checkpoint:
            return self._checkpoint_forward(hidden_states)
        else:
            return self._forward(hidden_states)
@LAYERS.register_module
class ViTHead2D(ParallelLayer):
    """Classification head for the 2D parallel Vision Transformer.

    Takes the first token of the sequence (conventionally the class token)
    and projects it onto the class logits with a 2D-parallel linear layer.

    :param hidden_size: hidden size
    :type hidden_size: int
    :param num_classes: number of classes
    :type num_classes: int
    :param dtype: dtype of parameters, defaults to None
    :type dtype: torch.dtype, optional
    """

    def __init__(self,
                 hidden_size,
                 num_classes,
                 dtype=None,
                 ):
        super().__init__()
        assert_summa_initialization()
        self.summa_dim = get_summa_dim_from_env()
        self.linear = Linear2D(hidden_size, num_classes, dtype=dtype)

    def forward(self, x: Tensor) -> Tensor:
        # keep only the first (class) token before classification
        cls_embedding = x[:, 0]
        return self.linear(cls_embedding)
@LAYERS.register_module
class ViTPatchEmbedding2D(ParallelLayer):
    """ 2D Image to Patch Embedding

    The convolution's output channels (the embedding dimension) are
    partitioned across the SUMMA dimension; weights are synchronized
    across the column group in both forward (broadcast at init) and
    backward (all-reduce grad hook).

    :param img_size: image size
    :type img_size: int
    :param patch_size: patch size
    :type patch_size: int
    :param embed_dim: dimension of embedding
    :type embed_dim: int
    :param in_chans: number of channels of input image, defaults to 3
    :type in_chans: int, optional
    :param flatten: whether to flatten output tensor, defaults to True
    :type flatten: bool, optional
    """
    def __init__(self,
                 img_size,
                 patch_size,
                 embed_dim,
                 in_chans=3,
                 flatten=True):
        super().__init__()
        img_size = to_2tuple(img_size)
        patch_size = to_2tuple(patch_size)
        assert_summa_initialization()
        self.summa_dim = get_summa_dim_from_env()
        self.img_size = img_size
        self.patch_size = patch_size
        self.grid_size = (img_size[0] // patch_size[0],
                          img_size[1] // patch_size[1])
        self.num_patches = self.grid_size[0] * self.grid_size[1]
        self.flatten = flatten
        # the embedding dimension is partitioned along the SUMMA dimension
        self.embed_dim = embed_dim // self.summa_dim
        with seed(ParallelMode.TENSOR):
            # ensure the partitions are initialized differently
            self.proj = nn.Conv2d(in_chans,
                                  self.embed_dim,
                                  kernel_size=patch_size,
                                  stride=patch_size
                                  )
        # sync initial weights across the column group
        self._broadcast_conv_params()
        self.proj.weight.register_hook(self._sync_grad_during_backward)
        self.proj.bias.register_hook(self._sync_grad_during_backward)
        # FIX: this was defined but never invoked, so the partitioned conv
        # parameters were never tagged as tensor-parallel (cf. ViTTokenFuser2D,
        # which calls its counterpart at the end of __init__)
        self._set_tensor_parallel_attribute()

    def _set_tensor_parallel_attribute(self):
        # tag the partitioned parameters so the framework can distinguish
        # them from replicated ones
        set_tensor_parallel_attribute(self.proj.weight)
        set_tensor_parallel_attribute(self.proj.bias)

    def _broadcast_conv_params(self) -> None:
        """Broadcast the conv parameters from the first rank of the column
        group so all column ranks start from identical weights."""
        self.to(get_current_device())
        ranks_in_col = gpc.get_ranks_in_group(ParallelMode.PARALLEL_2D_COL)
        dist.broadcast(self.proj.weight, src=ranks_in_col[0],
                       group=gpc.get_group(ParallelMode.PARALLEL_2D_COL))
        dist.broadcast(self.proj.bias, src=ranks_in_col[0],
                       group=gpc.get_group(ParallelMode.PARALLEL_2D_COL))

    def _sync_grad_during_backward(self, grad: Tensor) -> None:
        """All-reduce (and average) the gradient across the column group."""
        dist.all_reduce(grad, group=gpc.get_group(
            ParallelMode.PARALLEL_2D_COL))
        grad = grad / self.summa_dim
        return grad

    def forward(self, x: Tensor) -> Tensor:
        B, C, H, W = x.shape
        assert H == self.img_size[0] and W == self.img_size[1], \
            f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
        x = self.proj(x)
        if self.flatten:
            x = x.flatten(2).transpose(1, 2)  # BCHW -> BNC
        return x
@LAYERS.register_module
class ViTTokenFuser2D(ParallelLayer):
    """
    Fuse cls token and pos embedding to the input

    Both learnable tensors are partitioned along the SUMMA dimension and
    kept consistent across the column group: broadcast once at init, and
    gradient all-reduce via hooks during backward.

    :param img_size: image size
    :type img_size: int
    :param patch_size: patch size
    :type patch_size: int
    :param embed_dim: dimension of embedding
    :type embed_dim: int
    :param drop_rate: dropout probability, defaults to 0.
    :type drop_rate: float, optional
    """
    def __init__(self,
                 img_size,
                 patch_size,
                 embed_dim,
                 drop_rate=0.
                 ):
        super().__init__()
        img_size = to_2tuple(img_size)
        patch_size = to_2tuple(patch_size)
        assert_summa_initialization()
        self.summa_dim = get_summa_dim_from_env()
        self.img_size = img_size
        self.patch_size = patch_size
        self.grid_size = (img_size[0] // patch_size[0],
                          img_size[1] // patch_size[1])
        self.num_patches = self.grid_size[0] * self.grid_size[1]
        self.embed_dim = embed_dim
        self.cls_token = nn.Parameter(torch.zeros(
            1, 1, self.embed_dim // self.summa_dim))
        self.pos_embed = nn.Parameter(torch.zeros(
            1, self.num_patches + 1, self.embed_dim // self.summa_dim))
        # move to cuda before broadcast
        self.to(get_current_device())
        # FIX: broadcast and hook the parameters themselves.  The original
        # code concatenated flattened views with torch.cat, which produces a
        # *copy*: broadcasting the copy never wrote back into the parameters,
        # and a hook on the copy never fired (it takes no part in the forward
        # graph), leaving cls_token / pos_embed unsynchronized.
        self._broadcast_params(self.cls_token)
        self._broadcast_params(self.pos_embed)
        self.cls_token.register_hook(self._sync_grad_hook)
        self.pos_embed.register_hook(self._sync_grad_hook)
        self.pos_drop = nn.Dropout(p=drop_rate)
        self._set_tensor_parallel_attribute()

    def _set_tensor_parallel_attribute(self):
        # tag the partitioned parameters so the framework can distinguish
        # them from replicated ones
        set_tensor_parallel_attribute(self.cls_token)
        set_tensor_parallel_attribute(self.pos_embed)

    def _broadcast_params(self, param) -> None:
        " broadcast to all column ranks for data consistency "
        ranks_in_col = gpc.get_ranks_in_group(ParallelMode.PARALLEL_2D_COL)
        col_group = gpc.get_group(ParallelMode.PARALLEL_2D_COL)
        dist.broadcast(param, src=ranks_in_col[0],
                       group=col_group)

    def _sync_grad_hook(self, grad) -> None:
        """All-reduce (and average) the gradient across the column group."""
        dist.all_reduce(grad, group=gpc.get_group(
            ParallelMode.PARALLEL_2D_COL))
        grad = grad / self.summa_dim
        return grad

    def forward(self, x: Tensor) -> Tensor:
        # stole cls_tokens impl from Phil Wang, thanks
        cls_token = self.cls_token.expand(x.shape[0], -1, -1)
        x = torch.cat((cls_token, x), dim=1)
        with seed(ParallelMode.TENSOR):
            x = self.pos_drop(x + self.pos_embed)
        return x
@LAYERS.register_module
class ViTInputSplitter2D(ParallelLayer):
    """Splits the input batch across the 2D tensor-parallel column group
    for the 2D parallel Vision Transformer.
    """

    def __init__(self):
        super().__init__()
        assert_summa_initialization()
        self.summa_dim = get_summa_dim_from_env()

    def forward(self, x: Tensor) -> Tensor:
        # delegate the actual scatter to the autograd-aware split op
        return _ViT_Split_Input_2D.apply(
            x,
            x.size(0),
            self.summa_dim,
            ParallelMode.PARALLEL_2D_COL,
        )
import math
import torch
import torch.distributed as dist
from torch import Tensor
from torch.nn import Parameter, init as init
from colossalai.context import seed, ParallelMode
from colossalai.core import global_context as gpc
from colossalai.registry import LAYERS
from colossalai.utils import get_current_device
from ._operation import Matmul_AB_2D, Add_Bias_2D, _LayerNorm_2D
from ._utils import get_summa_dim_from_env, assert_summa_initialization
from .._common_utils import divide, set_tensor_parallel_attribute
from ..base_layer import ParallelLayer
@LAYERS.register_module
class Linear2D(ParallelLayer):
    """ Linear layer for 2D parallelism
    :param in_features: size of each input sample
    :type in_features: int
    :param out_features: size of each output sample
    :type out_features: int
    :param bias: If set to ``False``, the layer will not learn an additive bias, defaults to True
    :type bias: bool, optional
    :param dtype: The dtype of parameters, defaults to None
    :type dtype: torch.dtype, optional
    :param skip_bias_add: If set to ``True``, it will skip bias add for linear layer, which is preserved for kernel fusion, defaults to False
    :type skip_bias_add: bool, optional
    """
    def __init__(self,
                 in_features: int,
                 out_features: int,
                 bias: bool = True,
                 dtype=None,
                 skip_bias_add: bool = False
                 ):
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.skip_bias_add = skip_bias_add
        # parallel settings
        assert_summa_initialization()
        # NOTE(review): the row index is read from the COL parallel mode and
        # the column index from the ROW mode -- this cross-wiring appears
        # intentional for the 2D mesh layout; confirm against _operation.py
        self.row_rank = gpc.get_local_rank(ParallelMode.PARALLEL_2D_COL)
        self.col_rank = gpc.get_local_rank(ParallelMode.PARALLEL_2D_ROW)
        self.summa_dim = get_summa_dim_from_env()
        # partitioning dimension: input and output features are each split
        # q ways across the SUMMA mesh
        self.input_size_per_partition = divide(
            self.in_features, self.summa_dim)
        self.hidden_size_per_partition = divide(
            self.out_features, self.summa_dim)
        # create weight, shape: [k/q, h/q]
        factory_kwargs = {'device': get_current_device(), 'dtype': dtype}
        self.weight = Parameter(torch.empty(
            self.input_size_per_partition,
            self.hidden_size_per_partition,
            **factory_kwargs))
        # create bias, shape: [h/q]
        if bias:
            self.bias = Parameter(torch.empty(
                self.hidden_size_per_partition,
                **factory_kwargs))
        else:
            self.register_parameter('bias', None)
        # initialize parameters
        self.reset_parameters()
        self._set_tensor_parallel_attributes()
    def _set_tensor_parallel_attributes(self):
        # tag the partitioned tensors so the framework can tell them apart
        # from replicated parameters
        set_tensor_parallel_attribute(self.weight)
        if self.bias is not None:
            set_tensor_parallel_attribute(self.bias)
    def reset_parameters(self) -> None:
        """Initialize weight and bias uniformly, computing the bounds from
        the *global* fan-in (``in_features``) so the partitioned layer
        matches an unpartitioned layer's init statistics.  RNG runs under
        the tensor-parallel seed."""
        # setting
        fan_in = self.in_features
        a = math.sqrt(5)
        nonlinearity = 'leaky_relu'
        # init weight
        std = init.calculate_gain(nonlinearity, a) / math.sqrt(fan_in)
        bound = math.sqrt(3.0) * std
        with seed(ParallelMode.TENSOR):
            init.uniform_(self.weight, -bound, bound)
        # init bias
        if self.bias is not None:
            bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0
            with seed(ParallelMode.TENSOR):
                init.uniform_(self.bias, -bound, bound)
    def forward(self, x: Tensor) -> Tensor:
        """Apply the distributed affine transform.

        Returns ``output``, or the tuple ``(output, bias)`` when
        ``skip_bias_add`` is set so the caller can fuse the bias add.
        """
        # input: [m/q, n/q, k/q]
        # output: [m/q, n/q, h/q]
        out_shape = x.shape[:-1] + (self.hidden_size_per_partition,)
        # distributed matmul against the [k/q, h/q] weight partition
        output = Matmul_AB_2D.apply(
            x,
            self.weight,
            self.summa_dim,
            out_shape,
            self.row_rank,
            self.col_rank,
            ParallelMode.PARALLEL_2D_ROW,
            ParallelMode.PARALLEL_2D_COL,
            self.data_parallel_rank,
            self.pipeline_parallel_rank,
            self.pipeline_parallel_size,
            self.tensor_parallel_size)
        if self.bias is not None:
            if self.skip_bias_add:
                # return the (broadcast) bias separately for later fusion
                bias = Add_Bias_2D.apply(
                    None,
                    self.bias,
                    self.hidden_size_per_partition,
                    self.row_rank,
                    self.col_rank,
                    ParallelMode.PARALLEL_2D_ROW,
                    ParallelMode.PARALLEL_2D_COL,
                    True,
                    self.data_parallel_rank,
                    self.pipeline_parallel_rank,
                    self.pipeline_parallel_size,
                    self.tensor_parallel_size
                )
                return output, bias
            else:
                output = Add_Bias_2D.apply(
                    output,
                    self.bias,
                    self.hidden_size_per_partition,
                    self.row_rank,
                    self.col_rank,
                    ParallelMode.PARALLEL_2D_ROW,
                    ParallelMode.PARALLEL_2D_COL,
                    False,
                    self.data_parallel_rank,
                    self.pipeline_parallel_rank,
                    self.pipeline_parallel_size,
                    self.tensor_parallel_size
                )
                return output
        else:
            return output
@LAYERS.register_module
class LayerNorm2D(ParallelLayer):
    r"""Layer Normalization for 2D parallelism
    :param normalized_shape: input shape from an expected input
    of size. :math:`[* \times \text{normalized_shape}[0] \times \text{normalized_shape}[1] \times \ldots \times \text{normalized_shape}[-1]]`
    If a single integer is used, it is treated as a singleton list, and this module will
    normalize over the last dimension which is expected to be of that specific size.
    :type normalized_shape: int
    :param eps: a value added to the denominator for numerical stability, defaults to 1e-05
    :type eps: float, optional
    :param dtype: The dtype of parameters, defaults to None
    :type dtype: torch.dtype, optional
    """
    def __init__(self,
                 normalized_shape: int,
                 eps: float = 1e-05,
                 dtype=None
                 ):
        super().__init__()
        # layer norm config
        self.normalized_shape = normalized_shape
        self.variance_epsilon = eps
        # parallel setting
        assert_summa_initialization()
        # NOTE(review): row index comes from the COL mode (and vice versa),
        # matching the convention used by Linear2D -- confirm before changing
        self.row_rank = gpc.get_local_rank(ParallelMode.PARALLEL_2D_COL)
        self.col_rank = gpc.get_local_rank(ParallelMode.PARALLEL_2D_ROW)
        self.summa_dim = get_summa_dim_from_env()
        # partitioning dimension: the affine parameters are split q ways
        self.partitioned_partition = divide(normalized_shape, self.summa_dim)
        # create parameters: only the first row holds the real gamma/beta;
        # Add_Bias_2D broadcasts them to the other rows in forward
        factory_kwargs = {'device': get_current_device(), 'dtype': dtype}
        if self.row_rank == 0:
            self.gamma = Parameter(torch.ones(
                self.partitioned_partition,
                **factory_kwargs))
            self.beta = Parameter(torch.zeros(
                self.partitioned_partition,
                **factory_kwargs))
        else:
            # scalar placeholders on non-root rows; their values are not used
            # in forward (the broadcast supplies real values).
            # NOTE(review): beta is initialized to 1.0 here rather than 0.0 --
            # presumably harmless because it is a placeholder, but confirm
            self.gamma = Parameter(torch.tensor(
                1.0,
                requires_grad=True,
                **factory_kwargs))
            self.beta = Parameter(torch.tensor(
                1.0,
                requires_grad=True,
                **factory_kwargs))
        self._set_tensor_parallel_attributes()
    def _set_tensor_parallel_attributes(self):
        # tag the partitioned parameters as tensor-parallel
        set_tensor_parallel_attribute(self.gamma)
        set_tensor_parallel_attribute(self.beta)
    def forward(self, x: Tensor) -> Tensor:
        """Normalize over the (row-partitioned) last dimension, then apply
        the broadcast affine transform ``gamma * x_hat + beta``."""
        with torch.no_grad():
            # mean over the full feature dim: local sum + row all-reduce
            E_x = torch.sum(x, dim=-1, keepdim=True)  # [b/q, s, 1]
            torch.distributed.all_reduce(
                E_x, group=gpc.get_group(ParallelMode.PARALLEL_2D_ROW))
            E_x /= self.normalized_shape
            # Var_x in the block below is the sum of input^2
            Var_x = torch.sum(x * x, dim=-1, keepdim=True)  # [b/q, s, 1]
            torch.distributed.all_reduce(
                Var_x, group=gpc.get_group(ParallelMode.PARALLEL_2D_ROW))
            Var_x /= self.normalized_shape
            Var_x = Var_x - E_x * E_x  # variance of x [b/q, s, 1]
            # this time 1/sqrt(Var_x + epsilon)
            Var_x = 1.0 / torch.sqrt(Var_x + self.variance_epsilon)
        output = _LayerNorm_2D.apply(x, E_x, Var_x, self.normalized_shape,
                                     ParallelMode.PARALLEL_2D_ROW, ParallelMode.PARALLEL_2D_COL)
        # fetch beta/gamma from the owning (row 0) ranks via the bias op
        bias = Add_Bias_2D.apply(
            None, self.beta, self.partitioned_partition,
            self.row_rank, self.col_rank,
            ParallelMode.PARALLEL_2D_ROW, ParallelMode.PARALLEL_2D_COL,
            True,
            self.data_parallel_rank,
            self.pipeline_parallel_rank,
            self.pipeline_parallel_size,
            self.tensor_parallel_size
        )
        scale = Add_Bias_2D.apply(
            None, self.gamma, self.partitioned_partition,
            self.row_rank, self.col_rank,
            ParallelMode.PARALLEL_2D_ROW, ParallelMode.PARALLEL_2D_COL,
            True,
            self.data_parallel_rank,
            self.pipeline_parallel_rank,
            self.pipeline_parallel_size,
            self.tensor_parallel_size
        )
        # output = bias + scale * output
        output = torch.addcmul(bias, scale, output)
        return output
from ._operation import Matmul_AB_2p5D, Matmul_ABT_2p5D, Matmul_ATB_2p5D, Sum_2p5D, Add_Bias_2p5D
from ._transformer import TransformerMLP2p5D, TransformerSelfAttention2p5D, TransformerLayer2p5D
from ._vit import (ViTMLP2p5D, ViTSelfAttention2p5D, ViTHead2p5D, ViTPatchEmbedding2p5D, ViTTokenFuser2p5D,
ViTInputSplitter2p5D)
from .layers import Linear2p5D, LayerNorm2p5D
__all__ = [
'Matmul_AB_2p5D', 'Matmul_ABT_2p5D', 'Matmul_ATB_2p5D', 'Sum_2p5D', 'Add_Bias_2p5D',
'TransformerMLP2p5D', 'TransformerSelfAttention2p5D', 'TransformerLayer2p5D',
'ViTMLP2p5D', 'ViTSelfAttention2p5D', 'ViTHead2p5D', 'ViTPatchEmbedding2p5D', 'ViTTokenFuser2p5D',
'ViTInputSplitter2p5D',
'Linear2p5D', 'LayerNorm2p5D'
]
from typing import Any, Tuple
import torch
import torch.distributed as dist
from torch import Tensor
from colossalai.context.parallel_mode import ParallelMode
from colossalai.core import global_context as gpc
from colossalai.utils import get_current_device, empty_cache
def get_parallel_group(parallel_mode: ParallelMode):
    """Return the process group associated with ``parallel_mode``."""
    return gpc.get_group(parallel_mode)
def get_global_rank():
    """Return this process's global (world) rank."""
    return gpc.get_global_rank()
def get_parallel_rank(parallel_mode: ParallelMode):
    """Return this process's local rank within ``parallel_mode``'s group."""
    return gpc.get_local_rank(parallel_mode)
class Matmul_AB_2p5D(torch.autograd.Function):
    """Matrix multiplication for :math:`C = AB`

    SUMMA-style distributed matmul for 2.5D tensor parallelism: each loop
    iteration broadcasts one block of A along the row group and one block
    of B along the column group, then accumulates the local partial product.
    """
    @staticmethod
    def forward(ctx: Any,
                A: Tensor,
                B: Tensor,
                tesseract_dim: int,
                tesseract_dep: int,
                out_shape: Tuple[int, ...],
                row_rank: int,
                col_rank: int,
                dep_rank: int,
                row_parallel_mode: ParallelMode,
                col_parallel_mode: ParallelMode,
                dep_parallel_mode: ParallelMode,
                data_parallel_rank: int,
                pipeline_parallel_rank: int,
                pipeline_parallel_size: int,
                tensor_parallel_size: int) -> Tensor:
        # A: [b / dq, s, h / q] -> [(b * s) / dq, h / q]
        # B: [h / dq, s / q]
        # C: [b / dq, s, s / q] -> [(b * s) / dq, s / q]
        assert A.shape[-1] == B.shape[-2], \
            'Invalid shapes: A={}, B={} for AB.'.format(A.shape, B.shape)
        empty_cache()
        # ctx is None when this is invoked manually from a sibling backward()
        if ctx:
            ctx.save_for_backward(A, B)
        # collapse leading dims so the loop works on 2D matrices
        A_shape = A.shape
        A = A.reshape((-1, A_shape[-1]))
        B_shape = B.shape
        B = B.reshape((-1, B_shape[-1]))
        C_shape = (A.shape[0], B.shape[-1])
        C = torch.zeros(C_shape, dtype=A.dtype, device=get_current_device())
        for i in range(tesseract_dim):
            A_temp = A.clone()
            B_temp = B.clone()
            # global rank of the i-th peer in this row group; the dep /
            # data / pipeline coordinates offset into the flat world numbering
            src_a = i + row_rank * tesseract_dim + dep_rank * (
                tesseract_dim ** 2) + data_parallel_rank * pipeline_parallel_size * tensor_parallel_size + pipeline_parallel_rank * tensor_parallel_size
            dist.broadcast(A_temp, src=src_a,
                           group=get_parallel_group(row_parallel_mode))
            # global rank of the i-th peer in this column group
            src_b = col_rank + i * tesseract_dim + dep_rank * (
                tesseract_dim ** 2) + data_parallel_rank * pipeline_parallel_size * tensor_parallel_size + pipeline_parallel_rank * tensor_parallel_size
            dist.broadcast(B_temp, src=src_b,
                           group=get_parallel_group(col_parallel_mode))
            # C += A_temp @ B_temp, accumulated in place
            torch.addmm(C, A_temp, B_temp, out=C)
        out = C.reshape(out_shape)
        if ctx:
            # stash the topology so backward can rebuild the same calls
            ctx.tesseract_dim = tesseract_dim
            ctx.tesseract_dep = tesseract_dep
            ctx.row_rank = row_rank
            ctx.col_rank = col_rank
            ctx.dep_rank = dep_rank
            ctx.row_parallel_mode = row_parallel_mode
            ctx.col_parallel_mode = col_parallel_mode
            ctx.dep_parallel_mode = dep_parallel_mode
            ctx.A_shape = A_shape
            ctx.B_shape = B_shape
            ctx.data_parallel_rank = data_parallel_rank
            ctx.pipeline_parallel_rank = pipeline_parallel_rank
            ctx.pipeline_parallel_size = pipeline_parallel_size
            ctx.tensor_parallel_size = tensor_parallel_size
        return out
    @staticmethod
    def backward(ctx: Any, output_grad: Tensor) -> Tuple[Tensor, ...]:
        """dA = dC @ B^T and dB = A^T @ dC, computed by calling the sibling
        distributed kernels directly (ctx=None skips their bookkeeping)."""
        A, B = ctx.saved_tensors
        A_grad = Matmul_ABT_2p5D.forward(
            None,
            output_grad, B,
            ctx.tesseract_dim, ctx.tesseract_dep, ctx.A_shape,
            ctx.row_rank, ctx.col_rank, ctx.dep_rank,
            ctx.row_parallel_mode,
            ctx.col_parallel_mode,
            ctx.dep_parallel_mode,
            ctx.data_parallel_rank,
            ctx.pipeline_parallel_rank,
            ctx.pipeline_parallel_size,
            ctx.tensor_parallel_size
        )
        B_grad = Matmul_ATB_2p5D.forward(
            None,
            A, output_grad,
            ctx.tesseract_dim, ctx.tesseract_dep, ctx.B_shape,
            ctx.row_rank, ctx.col_rank, ctx.dep_rank,
            ctx.row_parallel_mode,
            ctx.col_parallel_mode,
            ctx.dep_parallel_mode,
            ctx.data_parallel_rank,
            ctx.pipeline_parallel_rank,
            ctx.pipeline_parallel_size,
            ctx.tensor_parallel_size
        )
        # one gradient per forward argument (13 non-tensor args get None)
        return A_grad, B_grad, None, None, None, None, None, None, None, None, None, None, None, None, None
class Matmul_ABT_2p5D(torch.autograd.Function):
    """Matrix multiplication for :math:`C = AB^T`

    Each iteration broadcasts one block of B along the column group,
    computes the local partial product, and reduces it onto the rank of the
    row group that owns that output block.
    """
    @staticmethod
    def forward(ctx: Any,
                A: Tensor,
                B: Tensor,
                tesseract_dim: int,
                tesseract_dep: int,
                out_shape: Tuple[int, ...],
                row_rank: int,
                col_rank: int,
                dep_rank: int,
                row_parallel_mode: ParallelMode,
                col_parallel_mode: ParallelMode,
                dep_parallel_mode: ParallelMode,
                data_parallel_rank: int,
                pipeline_parallel_rank: int,
                pipeline_parallel_size: int,
                tensor_parallel_size: int
                ) -> Tensor:
        assert A.shape[-1] == B.shape[-1], \
            'Invalid shapes: A={}, B={} for ABT.'.format(A.shape, B.shape)
        empty_cache()
        # ctx is None when this is invoked manually from a sibling backward()
        if ctx:
            ctx.save_for_backward(A, B)
        # collapse leading dims so the loop works on 2D matrices
        A_shape = A.shape
        A = A.reshape((-1, A_shape[-1]))
        B_shape = B.shape
        B = B.reshape((-1, B_shape[-1]))
        C_shape = (A.shape[0], B.shape[0])
        C = torch.empty(C_shape, dtype=A.dtype, device=get_current_device())
        for i in range(tesseract_dim):
            B_temp = B.clone()
            # global rank of the i-th peer in this column group
            src_b = col_rank + i * tesseract_dim + dep_rank * (
                tesseract_dim ** 2) + data_parallel_rank * pipeline_parallel_size * tensor_parallel_size + \
                pipeline_parallel_rank * tensor_parallel_size
            dist.broadcast(B_temp, src=src_b, group=gpc.get_group(col_parallel_mode))
            C_temp = torch.matmul(A, B_temp.transpose(0, 1))
            # reduce the partial product onto the owner of output block i
            src_c = i + row_rank * tesseract_dim + dep_rank * (
                tesseract_dim ** 2) + data_parallel_rank * pipeline_parallel_size * tensor_parallel_size + \
                pipeline_parallel_rank * tensor_parallel_size
            dist.reduce(C_temp, dst=src_c, group=gpc.get_group(row_parallel_mode))
            # keep only the block this rank owns
            if i == col_rank:
                C = C_temp.clone()
        out = C.reshape(out_shape)
        if ctx:
            # stash the topology so backward can rebuild the same calls
            ctx.tesseract_dim = tesseract_dim
            ctx.tesseract_dep = tesseract_dep
            ctx.row_rank = row_rank
            ctx.col_rank = col_rank
            ctx.dep_rank = dep_rank
            ctx.row_parallel_mode = row_parallel_mode
            ctx.col_parallel_mode = col_parallel_mode
            ctx.dep_parallel_mode = dep_parallel_mode
            ctx.A_shape = A_shape
            ctx.B_shape = B_shape
            ctx.data_parallel_rank = data_parallel_rank
            ctx.pipeline_parallel_rank = pipeline_parallel_rank
            ctx.pipeline_parallel_size = pipeline_parallel_size
            ctx.tensor_parallel_size = tensor_parallel_size
        return out
    @staticmethod
    def backward(ctx: Any, output_grad: Tensor) -> Tuple[Tensor, ...]:
        """dA = dC @ B and dB = dC^T @ A, via the sibling distributed kernels."""
        A, B = ctx.saved_tensors
        A_grad = Matmul_AB_2p5D.forward(
            None,
            output_grad, B,
            ctx.tesseract_dim, ctx.tesseract_dep, ctx.A_shape,
            ctx.row_rank, ctx.col_rank, ctx.dep_rank,
            ctx.row_parallel_mode,
            ctx.col_parallel_mode,
            ctx.dep_parallel_mode,
            ctx.data_parallel_rank,
            ctx.pipeline_parallel_rank,
            ctx.pipeline_parallel_size,
            ctx.tensor_parallel_size
        )
        B_grad = Matmul_ATB_2p5D.forward(
            None,
            output_grad, A,
            ctx.tesseract_dim, ctx.tesseract_dep, ctx.B_shape,
            ctx.row_rank, ctx.col_rank, ctx.dep_rank,
            ctx.row_parallel_mode,
            ctx.col_parallel_mode,
            ctx.dep_parallel_mode,
            ctx.data_parallel_rank,
            ctx.pipeline_parallel_rank,
            ctx.pipeline_parallel_size,
            ctx.tensor_parallel_size
        )
        # one gradient per forward argument (13 non-tensor args get None)
        return A_grad, B_grad, None, None, None, None, None, None, None, None, None, None, None, None, None
class Matmul_ATB_2p5D(torch.autograd.Function):
    """Matrix multiplication for :math:`C = A^TB`

    Each iteration broadcasts one block of A along the row group, computes
    the local partial product, and reduces it onto the rank of the column
    group that owns that output block.
    """
    @staticmethod
    def forward(ctx: Any,
                A: Tensor,
                B: Tensor,
                tesseract_dim: int,
                tesseract_dep: int,
                out_shape: Tuple[int, ...],
                row_rank: int,
                col_rank: int,
                dep_rank: int,
                row_parallel_mode: ParallelMode,
                col_parallel_mode: ParallelMode,
                dep_parallel_mode: ParallelMode,
                data_parallel_rank: int,
                pipeline_parallel_rank: int,
                pipeline_parallel_size: int,
                tensor_parallel_size: int):
        assert A.shape[-2] == B.shape[-2], \
            'Invalid shapes: A={}, B={} for ATB.'.format(A.shape, B.shape)
        empty_cache()
        # ctx is None when this is invoked manually from a sibling backward()
        if ctx:
            ctx.save_for_backward(A, B)
        # collapse leading dims so the loop works on 2D matrices
        A_shape = A.shape
        A = A.reshape((-1, A_shape[-1]))
        B_shape = B.shape
        B = B.reshape((-1, B_shape[-1]))
        C_shape = (A.shape[-1], B.shape[-1])
        C = torch.empty(C_shape, dtype=A.dtype, device=get_current_device())
        for i in range(tesseract_dim):
            A_temp = A.clone()
            # global rank of the i-th peer in this row group
            src_a = i + row_rank * tesseract_dim + dep_rank * (
                tesseract_dim ** 2) + data_parallel_rank * pipeline_parallel_size * tensor_parallel_size + \
                pipeline_parallel_rank * tensor_parallel_size
            dist.broadcast(A_temp, src=src_a,
                           group=get_parallel_group(row_parallel_mode))
            C_temp = torch.matmul(A_temp.transpose(0, 1), B)
            # reduce the partial product onto the owner of output block i
            src_c = col_rank + i * tesseract_dim + dep_rank * (
                tesseract_dim ** 2) + data_parallel_rank * pipeline_parallel_size * tensor_parallel_size + \
                pipeline_parallel_rank * tensor_parallel_size
            dist.reduce(C_temp, dst=src_c,
                        group=get_parallel_group(col_parallel_mode))
            # keep only the block this rank owns
            if i == row_rank:
                C = C_temp.clone()
        out = C.reshape(out_shape)
        if ctx:
            # stash the topology so backward can rebuild the same calls
            ctx.tesseract_dim = tesseract_dim
            ctx.tesseract_dep = tesseract_dep
            ctx.row_rank = row_rank
            ctx.col_rank = col_rank
            ctx.dep_rank = dep_rank
            ctx.row_parallel_mode = row_parallel_mode
            ctx.col_parallel_mode = col_parallel_mode
            ctx.dep_parallel_mode = dep_parallel_mode
            ctx.A_shape = A_shape
            ctx.B_shape = B_shape
            ctx.data_parallel_rank = data_parallel_rank
            ctx.pipeline_parallel_rank = pipeline_parallel_rank
            ctx.pipeline_parallel_size = pipeline_parallel_size
            ctx.tensor_parallel_size = tensor_parallel_size
        return out
    @staticmethod
    def backward(ctx: Any, output_grad: Tensor) -> Tuple[Tensor, ...]:
        """dA = B @ dC^T and dB = A @ dC, via the sibling distributed kernels."""
        A, B = ctx.saved_tensors
        A_grad = Matmul_ABT_2p5D.forward(
            None,
            B, output_grad,
            ctx.tesseract_dim, ctx.tesseract_dep, ctx.A_shape,
            ctx.row_rank, ctx.col_rank, ctx.dep_rank,
            ctx.row_parallel_mode,
            ctx.col_parallel_mode,
            ctx.dep_parallel_mode,
            ctx.data_parallel_rank,
            ctx.pipeline_parallel_rank,
            ctx.pipeline_parallel_size,
            ctx.tensor_parallel_size
        )
        B_grad = Matmul_AB_2p5D.forward(
            None,
            A, output_grad,
            ctx.tesseract_dim, ctx.tesseract_dep, ctx.B_shape,
            ctx.row_rank, ctx.col_rank, ctx.dep_rank,
            ctx.row_parallel_mode,
            ctx.col_parallel_mode,
            ctx.dep_parallel_mode,
            ctx.data_parallel_rank,
            ctx.pipeline_parallel_rank,
            ctx.pipeline_parallel_size,
            ctx.tensor_parallel_size
        )
        # one gradient per forward argument (13 non-tensor args get None)
        return A_grad, B_grad, None, None, None, None, None, None, None, None, None, None, None, None, None
class Add_Bias_2p5D(torch.autograd.Function):
    """Matrix add bias: :math:`C = A + b`

    The bias is owned by the ranks with ``row_rank == 0`` and broadcast
    along the column group before being added -- or returned directly when
    ``skip_bias_add`` is set, preserved for later kernel fusion.
    """
    @staticmethod
    def forward(ctx: Any,
                input: Tensor,
                bias: Tensor,
                output_size_per_partition: int,
                tesseract_dim: int,
                tesseract_dep: int,
                row_rank: int,
                col_rank: int,
                dep_rank: int,
                row_parallel_mode: ParallelMode,
                col_parallel_mode: ParallelMode,
                dep_parallel_mode: ParallelMode,
                skip_bias_add: bool,
                data_parallel_rank: int,
                pipeline_parallel_rank: int,
                pipeline_parallel_size: int,
                tensor_parallel_size: int
                ) -> Tensor:
        # only the first row owns the real bias; the other rows allocate a
        # zero buffer to receive the broadcast
        if row_rank == 0:
            bias_temp = bias.clone()
        else:
            bias_temp = torch.zeros(
                output_size_per_partition,
                dtype=bias.dtype,
                device=get_current_device())
        # global rank of the bias owner (row 0) within this column group
        src_rank = col_rank + dep_rank * (
            tesseract_dim ** 2) + data_parallel_rank * pipeline_parallel_size * tensor_parallel_size + \
            pipeline_parallel_rank * tensor_parallel_size
        dist.broadcast(bias_temp, src=src_rank, group=get_parallel_group(col_parallel_mode))
        ctx.row_rank = row_rank
        ctx.col_rank = col_rank
        ctx.dep_rank = dep_rank
        ctx.tesseract_dim = tesseract_dim
        ctx.tesseract_dep = tesseract_dep
        ctx.row_parallel_mode = row_parallel_mode
        ctx.col_parallel_mode = col_parallel_mode
        ctx.dep_parallel_mode = dep_parallel_mode
        ctx.bias = skip_bias_add
        ctx.data_parallel_rank = data_parallel_rank
        ctx.pipeline_parallel_rank = pipeline_parallel_rank
        ctx.pipeline_parallel_size = pipeline_parallel_size
        ctx.tensor_parallel_size = tensor_parallel_size
        if skip_bias_add:
            return bias_temp
        else:
            output = input + bias_temp
            return output
    @staticmethod
    def backward(ctx, output_grad):
        """Reduce the bias gradient onto its owning (row 0) rank.

        FIX: ``backward`` must return exactly one gradient per ``forward``
        argument (16 here); the original returned 17 values on the non-skip
        path, which autograd rejects.
        """
        row_rank = ctx.row_rank
        col_rank = ctx.col_rank
        dep_rank = ctx.dep_rank
        tesseract_dim = ctx.tesseract_dim
        tesseract_dep = ctx.tesseract_dep
        row_parallel_mode = ctx.row_parallel_mode
        col_parallel_mode = ctx.col_parallel_mode
        dep_parallel_mode = ctx.dep_parallel_mode
        data_parallel_rank = ctx.data_parallel_rank
        pipeline_parallel_rank = ctx.pipeline_parallel_rank
        pipeline_parallel_size = ctx.pipeline_parallel_size
        tensor_parallel_size = ctx.tensor_parallel_size
        if ctx.bias:
            # forward returned the broadcast bias itself, so the incoming
            # grad IS the bias grad; reduce it onto the owning rank
            dst_rank = col_rank + dep_rank * (
                tesseract_dim ** 2) + data_parallel_rank * pipeline_parallel_size * tensor_parallel_size + \
                pipeline_parallel_rank * tensor_parallel_size
            dist.reduce(output_grad, dst=dst_rank, group=get_parallel_group(col_parallel_mode))
            if row_rank == 0:
                return None, output_grad, None, None, None, None, None, None, None, None, None, None, None, None, None, None
            else:
                # non-owning rows contribute a zero bias grad
                grad_tmp = torch.zeros_like(output_grad)
                return None, grad_tmp, None, None, None, None, None, None, None, None, None, None, None, None, None, None
        else:
            # bias grad: sum over every leading dim, then reduce to the owner
            reduce_dim = tuple(range(output_grad.ndim - 1))
            reduce = torch.sum(output_grad, dim=reduce_dim)
            dst_rank = col_rank + dep_rank * (
                tesseract_dim ** 2) + data_parallel_rank * pipeline_parallel_size * tensor_parallel_size + \
                pipeline_parallel_rank * tensor_parallel_size
            dist.reduce(reduce, dst=dst_rank, group=get_parallel_group(col_parallel_mode))
            if row_rank == 0:
                return output_grad, reduce, None, None, None, None, None, None, None, None, None, None, None, None, None, None
            else:
                reduce_tmp = torch.zeros_like(reduce)
                return output_grad, reduce_tmp, None, None, None, None, None, None, None, None, None, None, None, None, None, None
class _LayerNorm_2p5D(torch.autograd.Function):
    """Normalization step of 2.5D-parallel layer norm.

    Receives the mean ``E_x`` and the precomputed inverse standard
    deviation ``Var_x`` (already ``1 / sqrt(Var[x] + eps)``) and returns
    the normalized activations; backward all-reduces the per-rank partial
    sums over the row group.
    """
    @staticmethod
    def forward(ctx: Any,
                input: Tensor,
                E_x: Tensor,
                Var_x: Tensor,
                hidden_size: int,
                row_parallel_mode: ParallelMode,
                col_parallel_mode: ParallelMode,
                dep_parallel_mode: ParallelMode) -> Tensor:
        input = input - E_x
        # in here, input = x - E[x], Var_x = 1 / sqrt(Var[x] + eps)
        ctx.hidden_size = hidden_size
        output = input * Var_x
        # save the normalized output (not the raw input) for backward
        ctx.save_for_backward(output, Var_x)
        ctx.row_parallel_mode = row_parallel_mode
        ctx.col_parallel_mode = col_parallel_mode
        ctx.dep_parallel_mode = dep_parallel_mode
        return output
    @staticmethod
    def backward(ctx, output_grad):
        """Standard layer-norm input gradient, with the two reductions over
        the feature dimension all-reduced across the row group."""
        row_parallel_mode = ctx.row_parallel_mode
        col_parallel_mode = ctx.col_parallel_mode
        dep_parallel_mode = ctx.dep_parallel_mode
        x, Var_x = ctx.saved_tensors
        # in here, Var_x = 1 / sqrt(Var[x] + eps), x = (x - E[x]) * Var_x
        with torch.no_grad():
            # mean of dL/dy over the full feature dim
            output_grad_sum = torch.sum(output_grad, dim=-1, keepdim=True)
            torch.distributed.all_reduce(
                output_grad_sum, group=get_parallel_group(row_parallel_mode))
            output_grad_sum /= ctx.hidden_size
            # mean of dL/dy * x_hat over the full feature dim
            output_grad_mul_x_sum = torch.sum(
                output_grad * x, dim=-1, keepdim=True)
            torch.distributed.all_reduce(
                output_grad_mul_x_sum, group=get_parallel_group(row_parallel_mode))
            output_grad_mul_x_sum /= ctx.hidden_size
            # dL/dx = (dL/dy - mean(dL/dy) - x_hat * mean(dL/dy * x_hat)) / std
            input_grad = output_grad.clone()
            input_grad -= x * output_grad_mul_x_sum
            input_grad -= output_grad_sum
            input_grad *= Var_x
        # one gradient per forward argument
        return input_grad, None, None, None, None, None, None
class Sum_2p5D(torch.autograd.Function):
    """Sum ``inputs`` along ``dim`` and all-reduce the partial sums across
    the tesseract row group so every rank holds the complete reduction.
    """
    @staticmethod
    def forward(ctx,
                inputs,
                dim,
                tesseract_dim,
                row_parallel_mode,
                keepdim=False):
        # input: [b/q, s, h/q]
        empty_cache()
        # remember what backward needs to rebuild the input-shaped gradient;
        # no tensor has to be saved since d(sum)/d(input) does not depend on
        # the input values
        ctx.input_shape = inputs.shape
        ctx.dim = dim
        ctx.keepdim = keepdim
        # local partial sum, then all-reduce over the row group
        out = torch.sum(inputs, dim=dim, keepdim=keepdim)
        torch.distributed.all_reduce(
            out, group=gpc.get_group(row_parallel_mode))
        return out
    @staticmethod
    def backward(ctx, output_grad):
        """Broadcast the upstream gradient back over the reduced dimension.

        FIX: the original unpacked ``ctx.saved_tensors`` (a tuple) into a
        single name and then read ``.shape`` on it (AttributeError), returned
        ``torch.ones(...)`` which discarded the upstream gradient entirely,
        and returned 6 values for 5 forward arguments.
        """
        with torch.no_grad():
            grad = output_grad
            if not ctx.keepdim:
                grad = grad.unsqueeze(ctx.dim)
            input_grad = grad.expand(ctx.input_shape).contiguous()
        # one gradient per forward argument:
        # inputs, dim, tesseract_dim, row_parallel_mode, keepdim
        return input_grad, None, None, None, None
class _ViT_Split_2p5D(torch.autograd.Function):
    """Scatter the ViT input batch across the xz group: each rank keeps its
    own chunk of the batch dimension; backward all-gathers the gradients.
    """
    @staticmethod
    def forward(ctx, inputs, batch_size,
                tesseract_dim, tesseract_dep,
                xz_parallel_mode):
        # inputs: [b, s, h/q]
        # output: [b/dq, s, h/q]
        empty_cache()
        ctx.batch_size = batch_size
        ctx.tesseract_dim = tesseract_dim
        ctx.tesseract_dep = tesseract_dep
        ctx.xz_parallel_mode = xz_parallel_mode
        xz_rank = gpc.get_local_rank(xz_parallel_mode)
        # keep the chunk of the batch that belongs to this rank
        output = torch.chunk(inputs, tesseract_dep *
                             tesseract_dim, dim=0)[xz_rank]
        # clone so the output owns its memory rather than a view of inputs
        output = output.clone()
        return output
    @staticmethod
    def backward(ctx, output_grad):
        # output_grad: [b/dq, s, h/q]
        # grads: [b, s, h/q]
        # reassemble the full-batch gradient from every rank's chunk
        grads_shape = (ctx.batch_size,) + output_grad.shape[1:]
        grads = torch.empty(grads_shape,
                            dtype=output_grad.dtype,
                            device=get_current_device())
        dist.all_gather(list(grads.chunk(ctx.tesseract_dim * ctx.tesseract_dep, dim=0)),
                        output_grad.contiguous(),
                        group=get_parallel_group(ctx.xz_parallel_mode))
        # one gradient per forward argument
        return grads, None, None, None, None
#!/usr/bin/env python
# -*- encoding: utf-8 -*-
import math
import torch
from torch import nn as nn, Tensor
from colossalai.nn.layer._common_utils import divide
from colossalai.registry import LAYERS
from ._utils import assert_tesseract_initialization, \
get_tesseract_dim_dep_from_env
from .layers import Linear2p5D, LayerNorm2p5D
from .._common_utils import ACT2FN
@LAYERS.register_module
class TransformerMLP2p5D(nn.Module):
    """Feed-forward block of the 2.5D-parallel Transformer.

    Projects the hidden state to ``mlp_ratio * h``, applies a nonlinearity,
    projects back to ``h``, then applies dropout and a residual layernorm.

    :param in_features: the size of input tensor
    :type in_features: int
    :param mlp_ratio: hidden size of MLP divided by embedding dim
    :type mlp_ratio: int
    :param act_func: activation function, defaults to 'gelu'
    :type act_func: str, optional
    :param dropout_prob: dropout probability, defaults to 0.
    :type dropout_prob: float, optional
    :param dtype: dtype of parameters, defaults to None
    :type dtype: torch.dtype, optional
    """

    def __init__(self,
                 in_features: int,
                 mlp_ratio: int,
                 act_func: str = 'gelu',
                 dropout_prob: float = 0.,
                 dtype=None,
                 ):
        super().__init__()
        assert_tesseract_initialization()
        self.tesseract_dim, self.tesseract_dep = get_tesseract_dim_dep_from_env()
        self.in_features = in_features
        # expansion projection: h -> mlp_ratio * h
        self.dense_1 = Linear2p5D(in_features, mlp_ratio * in_features, dtype=dtype)
        assert act_func in ACT2FN.keys(), f'Invalid value for argument act_func, ' \
            f'activation function can only be {list(ACT2FN.keys())}'
        self.activation_func = ACT2FN[act_func]
        # contraction projection: mlp_ratio * h -> h
        self.dense_2 = Linear2p5D(mlp_ratio * in_features, in_features, dtype=dtype)
        self.dropout = nn.Dropout(dropout_prob)
        self.layernorm = LayerNorm2p5D(in_features, dtype=dtype)

    def forward(self, x: Tensor) -> Tensor:
        hidden = self.activation_func(self.dense_1(x))
        out = self.dropout(self.dense_2(hidden))
        # residual connection followed by layer normalization
        return self.layernorm(x + out)
@LAYERS.register_module
class TransformerSelfAttention2p5D(nn.Module):
    """Multi-head self-attention for the 2.5D parallel Transformer.

    :param hidden_size: hidden size
    :type hidden_size: int
    :param num_attention_heads: number of attention heads
    :type num_attention_heads: int
    :param attention_dropout_prob: dropout probability for attention layer
    :type attention_dropout_prob: float
    :param hidden_dropout_prob: dropout probability for hidden layer
    :type hidden_dropout_prob: float
    :param dtype: dtype of parameters, defaults to None
    :type dtype: torch.dtype, optional
    """

    def __init__(self,
                 hidden_size,
                 num_attention_heads,
                 attention_dropout_prob,
                 hidden_dropout_prob,
                 dtype=None,
                 ):
        super().__init__()
        assert_tesseract_initialization()
        self.tesseract_dim, self.tesseract_dep = get_tesseract_dim_dep_from_env()
        self.hidden_size = hidden_size
        # attention heads are partitioned across the tesseract dimension
        self.num_attention_heads = divide(num_attention_heads, self.tesseract_dim)
        self.attention_head_size = divide(hidden_size, num_attention_heads)
        self.all_head_size = self.num_attention_heads * self.attention_head_size
        # fused projection that produces q, k and v in a single linear layer
        self.query_key_value = Linear2p5D(hidden_size, 3 * hidden_size, dtype=dtype)
        self.attention_dropout = nn.Dropout(attention_dropout_prob)
        self.dense = Linear2p5D(hidden_size, hidden_size, dtype=dtype)
        self.dropout = nn.Dropout(hidden_dropout_prob)
        self.layernorm = LayerNorm2p5D(hidden_size, dtype=dtype)

    def forward(self, hidden_states: Tensor, attention_mask: Tensor) -> Tensor:
        # [b, s, 3h/q] -> [b, heads, s, 3 * head_size]
        qkv = self.query_key_value(hidden_states)
        qkv = qkv.view(qkv.shape[:-1] +
                       (self.num_attention_heads, 3 * self.attention_head_size))
        qkv = qkv.permute((0, 2, 1, 3))
        q, k, v = torch.chunk(qkv, 3, dim=-1)

        # scaled dot-product scores with additive mask
        scores = torch.matmul(q, k.transpose(-1, -2)) / \
            math.sqrt(self.attention_head_size)
        scores = scores + attention_mask
        probs = self.attention_dropout(nn.Softmax(dim=-1)(scores))

        # weighted sum of values, heads re-merged into the hidden dimension
        context = torch.matmul(probs, v)
        context = context.permute((0, 2, 1, 3)).contiguous()
        context = context.view(*(context.size()[:-2] + (self.all_head_size,)))

        out = self.dropout(self.dense(context))
        # residual connection followed by layer normalization
        return self.layernorm(hidden_states + out)
@LAYERS.register_module
class TransformerLayer2p5D(nn.Module):
    """One Transformer block: self-attention followed by an MLP, both with
    their own residual layernorm.

    :param hidden_size: hidden size
    :type hidden_size: int
    :param num_attention_heads: number of attention heads
    :type num_attention_heads: int
    :param act_func: activation function, defaults to 'gelu'
    :type act_func: str, optional
    :param mlp_ratio: hidden size of MLP divided by embedding dim, defaults to 4
    :type mlp_ratio: int, optional
    :param attention_dropout_prob: dropout probability for attention layer, defaults to 0.
    :type attention_dropout_prob: float, optional
    :param hidden_dropout_prob: dropout probability for hidden layer, defaults to 0.
    :type hidden_dropout_prob: float, optional
    :param dtype: dtype of parameters, defaults to None
    :type dtype: torch.dtype, optional
    """

    def __init__(self,
                 hidden_size,
                 num_attention_heads,
                 act_func='gelu',
                 mlp_ratio=4,
                 attention_dropout_prob: float = 0.,
                 hidden_dropout_prob: float = 0.,
                 dtype=None,
                 ):
        super().__init__()
        self.attention = TransformerSelfAttention2p5D(
            hidden_size=hidden_size,
            num_attention_heads=num_attention_heads,
            attention_dropout_prob=attention_dropout_prob,
            hidden_dropout_prob=hidden_dropout_prob,
            dtype=dtype,
        )
        self.mlp = TransformerMLP2p5D(
            in_features=hidden_size,
            dropout_prob=hidden_dropout_prob,
            act_func=act_func,
            mlp_ratio=mlp_ratio,
            dtype=dtype,
        )

    def forward(self, hidden_states: Tensor, attention_mask: Tensor) -> Tensor:
        # attention first, then the feed-forward block
        return self.mlp(self.attention(hidden_states, attention_mask))
import os
from colossalai.context.parallel_mode import ParallelMode
from colossalai.core import global_context as gpc
def get_tesseract_dim_dep_from_env():
    """Read the tesseract dimension and depth from the environment.

    :return: ``(tesseract_dim, tesseract_dep)`` parsed from the
        ``TESSERACT_DIM`` and ``TESSERACT_DEP`` environment variables
    :rtype: tuple
    :raises EnvironmentError: if either variable is missing
    :raises AssertionError: if either value is not positive
    """
    try:
        tesseract_dim = int(os.environ['TESSERACT_DIM'])
        tesseract_dep = int(os.environ['TESSERACT_DEP'])
    except KeyError as e:
        # chain the KeyError so the missing variable stays visible
        raise EnvironmentError('TESSERACT_DIM or TESSERACT_DEP is not found in the current environment, '
                               'please make sure that you have used the correct process group initializer') from e
    assert tesseract_dim > 0, 'TESSERACT_DIM must be larger than zero'
    assert tesseract_dep > 0, 'TESSERACT_DEP must be larger than zero'
    return tesseract_dim, tesseract_dep
def assert_tesseract_initialization():
    """Ensure every 2.5D parallel process group has been initialized."""
    required_modes = (ParallelMode.PARALLEL_2P5D_COL,
                      ParallelMode.PARALLEL_2P5D_ROW,
                      ParallelMode.PARALLEL_2P5D_DEP,
                      ParallelMode.PARALLEL_2P5D_XZ)
    assert all(gpc.is_initialized(mode) for mode in required_modes), \
        'Both PARALLEL_2P5D_COL, PARALLEL_2P5D_ROW, PARALLEL_2P5D_DEP and PARALLEL_2P5D_XZ must be initialized by the process group initializer'
#!/usr/bin/env python
# -*- encoding: utf-8 -*-
import math
import torch
from torch import nn as nn, Tensor, distributed as dist
from colossalai.context.parallel_mode import ParallelMode
from colossalai.core import global_context as gpc
from colossalai.nn.layer.vanilla_vision_transformer.layers import to_2tuple
from colossalai.registry import LAYERS
from colossalai.utils import get_current_device
from ._operation import _ViT_Split_2p5D
from ._utils import assert_tesseract_initialization, \
get_tesseract_dim_dep_from_env
from .layers import Linear2p5D
from .._common_utils import ACT2FN, divide, CheckpointModule
from .._common_utils import set_tensor_parallel_attribute
@LAYERS.register_module
class ViTMLP2p5D(CheckpointModule):
    """Feed-forward block for the 2.5D parallel Vision Transformer.

    :param in_features: size of each input sample
    :type in_features: int
    :param mlp_ratio: hidden size of MLP divided by embedding dim
    :type mlp_ratio: int
    :param act_func: activation function, defaults to 'gelu'
    :type act_func: str, optional
    :param dropout_prob: dropout probability, defaults to 0.
    :type dropout_prob: float, optional
    :param dtype: The dtype of parameters, defaults to None
    :type dtype: torch.dtype, optional
    :param checkpoint: If set to `True`, activation checkpoint is used, defaults to `False`
    :type checkpoint: bool, optional
    """

    def __init__(self,
                 in_features: int,
                 mlp_ratio: int,
                 act_func: str = 'gelu',
                 dropout_prob: float = 0.,
                 dtype=None,
                 checkpoint: bool = False
                 ):
        super().__init__(checkpoint=checkpoint)
        assert_tesseract_initialization()
        self.in_features = in_features
        self.mlp_ratio = mlp_ratio
        hidden_features = mlp_ratio * in_features
        # expansion projection: h -> mlp_ratio * h
        self.dense_1 = Linear2p5D(in_features, hidden_features, dtype=dtype)
        self.act = ACT2FN[act_func]
        # contraction projection: mlp_ratio * h -> h
        self.dense_2 = Linear2p5D(hidden_features, in_features, dtype=dtype)
        self.dropout = nn.Dropout(dropout_prob)

    def _forward(self, hidden_states: Tensor) -> Tensor:
        # dropout is applied both after the activation and after the
        # second projection
        hidden = self.dropout(self.act(self.dense_1(hidden_states)))
        return self.dropout(self.dense_2(hidden))
@LAYERS.register_module
class ViTSelfAttention2p5D(CheckpointModule):
    """Multi-head self-attention for the 2.5D parallel Vision Transformer.

    :param hidden_size: hidden size
    :type hidden_size: int
    :param num_attention_heads: number of attention heads
    :type num_attention_heads: int
    :param attention_dropout_prob: dropout probability for attention layers
    :type attention_dropout_prob: float
    :param hidden_dropout_prob: dropout probability for hidden layers
    :type hidden_dropout_prob: float
    :param dtype: dtype of parameters, defaults to None
    :type dtype: torch.dtype, optional
    :param checkpoint: If set to `True`, activation checkpoint is used, defaults to `False`
    :type checkpoint: bool, optional
    """

    def __init__(self,
                 hidden_size,
                 num_attention_heads,
                 attention_dropout_prob,
                 hidden_dropout_prob,
                 dtype=None,
                 checkpoint: bool = False
                 ):
        super().__init__(checkpoint=checkpoint)
        assert_tesseract_initialization()
        self.tesseract_dim, self.tesseract_dep = get_tesseract_dim_dep_from_env()
        self.hidden_size = hidden_size
        # attention heads are partitioned across the tesseract dimension
        self.num_attention_heads = divide(num_attention_heads, self.tesseract_dim)
        self.attention_head_size = divide(hidden_size, num_attention_heads)
        self.all_head_size = self.num_attention_heads * self.attention_head_size
        # fused projection that produces q, k and v in a single linear layer
        self.query_key_value = Linear2p5D(hidden_size, 3 * hidden_size, dtype=dtype)
        self.attention_dropout = nn.Dropout(attention_dropout_prob)
        self.dense = Linear2p5D(hidden_size, hidden_size, dtype=dtype)
        self.dropout = nn.Dropout(hidden_dropout_prob)

    def _forward(self, hidden_states: Tensor) -> Tensor:
        # [b, s, 3h/q] -> [b, heads, s, 3 * head_size]
        qkv = self.query_key_value(hidden_states)
        qkv = qkv.view(qkv.shape[:-1] +
                       (self.num_attention_heads, 3 * self.attention_head_size))
        qkv = qkv.permute((0, 2, 1, 3))
        q, k, v = torch.chunk(qkv, 3, dim=-1)

        # scaled dot-product attention (no mask in the ViT variant)
        scores = torch.matmul(q, k.transpose(-1, -2)) / \
            math.sqrt(self.attention_head_size)
        probs = self.attention_dropout(nn.Softmax(dim=-1)(scores))

        # weighted sum of values, heads re-merged into the hidden dimension
        context = torch.matmul(probs, v).transpose(1, 2)
        context = context.reshape(context.size()[:-2] + (self.all_head_size,))

        return self.dropout(self.dense(context))
@LAYERS.register_module
class ViTHead2p5D(nn.Module):
    """Classification head for the 2.5D parallel Vision Transformer.

    :param hidden_size: hidden size
    :type hidden_size: int
    :param num_classes: number of classes
    :type num_classes: int
    :param dtype: dtype of parameters, defaults to None
    :type dtype: torch.dtype, optional
    """

    def __init__(self,
                 hidden_size,
                 num_classes,
                 dtype=None,
                 ):
        super().__init__()
        assert_tesseract_initialization()
        self.linear = Linear2p5D(hidden_size, num_classes, dtype=dtype)

    def forward(self, x: Tensor) -> Tensor:
        # classify using the first ([CLS]) token only
        cls_repr = x[:, 0]
        return self.linear(cls_repr)
@LAYERS.register_module
class ViTPatchEmbedding2p5D(nn.Module):
    """ 2.5D Image to Patch Embedding.

    The embedding dimension is split across the tesseract dimension; conv
    parameters are broadcast within the XZ group at construction and their
    gradients are all-reduced (and averaged) during backward via hooks.

    :param img_size: image size
    :type img_size: int
    :param patch_size: patch size
    :type patch_size: int
    :param embed_dim: dimension of embedding
    :type embed_dim: int
    :param in_chans: number of channels of input image, defaults to 3
    :type in_chans: int, optional
    :param flatten: whether to flatten output tensor, defaults to True
    :type flatten: bool, optional
    """
    def __init__(self,
                 img_size,
                 patch_size,
                 embed_dim,
                 in_chans=3,
                 flatten=True):
        super().__init__()
        img_size = to_2tuple(img_size)
        patch_size = to_2tuple(patch_size)
        assert_tesseract_initialization()
        self.tesseract_dim, self.tesseract_dep = get_tesseract_dim_dep_from_env()
        self.img_size = img_size
        self.patch_size = patch_size
        self.grid_size = (img_size[0] // patch_size[0],
                          img_size[1] // patch_size[1])
        self.num_patches = self.grid_size[0] * self.grid_size[1]
        self.flatten = flatten
        # each rank holds only its 1/tesseract_dim slice of the embedding dim
        self.embed_dim = embed_dim // self.tesseract_dim  # *
        self.proj = nn.Conv2d(in_chans,
                              self.embed_dim,
                              kernel_size=patch_size,
                              stride=patch_size,
                              )
        # move self to cuda before sync
        self.to(get_current_device())
        # sync conv parameters within the XZ group so ranks start identical
        self._broadcast_conv_params()
        # keep conv gradients synchronized across the XZ group
        self.proj.weight.register_hook(self._sync_grad_during_backward)
        self.proj.bias.register_hook(self._sync_grad_during_backward)
    def _broadcast_conv_params(self) -> None:
        # broadcast from the first rank of the XZ group to the rest
        xz_rank = gpc.get_ranks_in_group(ParallelMode.PARALLEL_2P5D_XZ)
        dist.broadcast(self.proj.weight, src=xz_rank[0],
                       group=gpc.get_group(ParallelMode.PARALLEL_2P5D_XZ))
        dist.broadcast(self.proj.bias, src=xz_rank[0],
                       group=gpc.get_group(ParallelMode.PARALLEL_2P5D_XZ))
    def _sync_grad_during_backward(self, grad: Tensor) -> Tensor:
        # sum gradients across the XZ group, then average them
        dist.all_reduce(grad, group=gpc.get_group(
            ParallelMode.PARALLEL_2P5D_XZ))
        grad = grad / self.tesseract_dim / self.tesseract_dep  # *
        return grad
    def forward(self, x: Tensor) -> Tensor:
        B, C, H, W = x.shape
        assert H == self.img_size[0] and W == self.img_size[1], \
            f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
        x = self.proj(x)
        if self.flatten:
            x = x.flatten(2).transpose(1, 2)  # BCHW -> BNC
        return x
@LAYERS.register_module
class ViTTokenFuser2p5D(nn.Module):
    """
    Fuse cls token and pos embedding to the input.

    Both parameters hold only the local 1/tesseract_dim slice of the
    embedding dimension; they are broadcast within the XZ group at
    construction and their gradients are all-reduced (and averaged)
    during backward via hooks.

    :param img_size: image size
    :type img_size: int
    :param patch_size: patch size
    :type patch_size: int
    :param embed_dim: dimension of embedding
    :type embed_dim: int
    :param drop_rate: dropout probability, defaults to 0.
    :type drop_rate: float, optional
    """
    def __init__(self,
                 img_size,
                 patch_size,
                 embed_dim,
                 drop_rate=0.
                 ):
        super().__init__()
        img_size = to_2tuple(img_size)
        patch_size = to_2tuple(patch_size)
        assert_tesseract_initialization()
        self.tesseract_dim, self.tesseract_dep = get_tesseract_dim_dep_from_env()
        self.img_size = img_size
        self.patch_size = patch_size
        self.grid_size = (img_size[0] // patch_size[0],
                          img_size[1] // patch_size[1])
        self.num_patches = self.grid_size[0] * self.grid_size[1]
        self.embed_dim = embed_dim
        # each rank holds 1/tesseract_dim of the embedding dimension
        self.cls_token = nn.Parameter(torch.zeros(
            1, 1, self.embed_dim // self.tesseract_dim))  # *
        self.pos_embed = nn.Parameter(torch.zeros(
            1, self.num_patches + 1, self.embed_dim // self.tesseract_dim))  # *
        # move to cuda before broadcast
        self.to(get_current_device())
        self._broadcast_params()
        # keep parameter gradients synchronized across the XZ group
        self.cls_token.register_hook(self._sync_grad_hook)
        self.pos_embed.register_hook(self._sync_grad_hook)
        self.pos_drop = nn.Dropout(p=drop_rate)
        self._set_tensor_parallel_attribute()
    def _set_tensor_parallel_attribute(self):
        # mark both parameters as tensor-parallel partitioned
        set_tensor_parallel_attribute(self.cls_token)
        set_tensor_parallel_attribute(self.pos_embed)
    def _broadcast_params(self) -> None:
        " broadcast to all column ranks for data consistency "
        xz_rank = gpc.get_ranks_in_group(ParallelMode.PARALLEL_2P5D_XZ)
        dist.broadcast(self.cls_token, src=xz_rank[0],
                       group=gpc.get_group(ParallelMode.PARALLEL_2P5D_XZ))
        dist.broadcast(self.pos_embed, src=xz_rank[0],
                       group=gpc.get_group(ParallelMode.PARALLEL_2P5D_XZ))
    def _sync_grad_hook(self, grad) -> Tensor:
        # sum gradients across the XZ group, then average them
        dist.all_reduce(grad, group=gpc.get_group(
            ParallelMode.PARALLEL_2P5D_XZ))
        grad = grad / self.tesseract_dim / self.tesseract_dep  # *
        return grad
    def forward(self, x: Tensor) -> Tensor:
        # stole cls_tokens impl from Phil Wang, thanks
        cls_token = self.cls_token.expand(x.shape[0], -1, -1)
        x = torch.cat((cls_token, x), dim=1)
        x = self.pos_drop(x + self.pos_embed)
        return x
@LAYERS.register_module
class ViTInputSplitter2p5D(nn.Module):
    """Split the input batch across the 2.5D XZ parallel group."""

    def __init__(self):
        super().__init__()
        assert_tesseract_initialization()
        self.tesseract_dim, self.tesseract_dep = get_tesseract_dim_dep_from_env()

    def forward(self, x: Tensor) -> Tensor:
        # keep only this rank's chunk of the batch dimension
        return _ViT_Split_2p5D.apply(x,
                                     x.size(0),
                                     self.tesseract_dim,
                                     self.tesseract_dep,
                                     ParallelMode.PARALLEL_2P5D_XZ)
import math
import torch
from torch import Tensor
from torch.nn import Parameter, init as init
from colossalai.context import seed, ParallelMode
from colossalai.core import global_context as gpc
from colossalai.registry import LAYERS
from colossalai.utils import get_current_device
from ._operation import Matmul_AB_2p5D, Add_Bias_2p5D, _LayerNorm_2p5D
from ._utils import get_tesseract_dim_dep_from_env, assert_tesseract_initialization
from .._common_utils import divide, set_tensor_parallel_attribute
from ..base_layer import ParallelLayer
@LAYERS.register_module
class Linear2p5D(ParallelLayer):
    """Linear layer for 2.5D parallelism.

    The weight is partitioned as [k/q, h/q] across the tesseract grid and
    the bias as [h/q]. The matmul is performed by ``Matmul_AB_2p5D`` and the
    bias addition by ``Add_Bias_2p5D``.

    :param in_features: size of each input sample
    :type in_features: int
    :param out_features: size of each output sample
    :type out_features: int
    :param bias: If set to ``False``, the layer will not learn an additive bias, defaults to True
    :type bias: bool, optional
    :param dtype: The dtype of parameters, defaults to None
    :type dtype: torch.dtype, optional
    :param skip_bias_add: if ``True``, ``forward`` returns ``(output, bias)``
        without adding the bias, so the caller can fuse the addition
    :type skip_bias_add: bool, optional
    """
    def __init__(self,
                 in_features: int,
                 out_features: int,
                 bias: bool = True,
                 dtype=None,
                 skip_bias_add: bool = False
                 ):
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.skip_bias_add = skip_bias_add
        # parallel setting
        assert_tesseract_initialization()
        self.row_rank = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_COL)
        self.col_rank = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_ROW)
        self.dep_rank = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_DEP)
        self.tesseract_dim, self.tesseract_dep = get_tesseract_dim_dep_from_env()
        # partitioning dimension: both in and out features are split by q
        self.input_size_per_partition = divide(in_features, self.tesseract_dim)
        self.hidden_size_per_partition = divide(
            out_features, self.tesseract_dim)
        # create weight, shape: [k/q, h/q]
        factory_kwargs = {'device': get_current_device(), 'dtype': dtype}
        self.weight = Parameter(torch.empty(
            self.input_size_per_partition,
            self.hidden_size_per_partition,
            **factory_kwargs))
        # create bias, shape: [h/q]
        if bias:
            self.bias = Parameter(torch.empty(
                self.hidden_size_per_partition,
                **factory_kwargs))
        else:
            self.register_parameter('bias', None)
        # initialize parameters
        self.reset_parameters()
        self._set_tensor_parallel_attributes()
    def _set_tensor_parallel_attributes(self):
        # mark the partitioned parameters as tensor-parallel
        set_tensor_parallel_attribute(self.weight)
        if self.bias is not None:
            set_tensor_parallel_attribute(self.bias)
    def reset_parameters(self) -> None:
        """Initialize weight and bias with a kaiming-uniform style scheme
        based on the full (unpartitioned) fan-in, under the TENSOR seed so
        all tensor-parallel ranks draw consistent random numbers."""
        # setting
        fan_in = self.in_features
        a = math.sqrt(5)
        nonlinearity = 'leaky_relu'
        # init weight
        std = init.calculate_gain(nonlinearity, a) / math.sqrt(fan_in)
        bound = math.sqrt(3.0) * std
        with seed(ParallelMode.TENSOR):
            init.uniform_(self.weight, -bound, bound)
        # init bias
        if self.bias is not None:
            bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0
            with seed(ParallelMode.TENSOR):
                init.uniform_(self.bias, -bound, bound)
    def forward(self, x: Tensor) -> Tensor:
        # input: [m/dq, n/q, k/q]
        # output: [m/dq, n/q, h/q]
        out_shape = x.shape[:-1] + (self.hidden_size_per_partition,)
        output = Matmul_AB_2p5D.apply(
            x,
            self.weight,
            self.tesseract_dim,
            self.tesseract_dep,
            out_shape,
            self.row_rank, self.col_rank, self.dep_rank,
            ParallelMode.PARALLEL_2P5D_ROW,
            ParallelMode.PARALLEL_2P5D_COL,
            ParallelMode.PARALLEL_2P5D_DEP,
            self.data_parallel_rank,
            self.pipeline_parallel_rank,
            self.pipeline_parallel_size,
            self.tensor_parallel_size,
        )
        if self.bias is not None:
            if self.skip_bias_add:
                # return the (broadcast) bias separately for fused addition
                bias = Add_Bias_2p5D.apply(
                    None,
                    self.bias,
                    self.hidden_size_per_partition,
                    self.tesseract_dim, self.tesseract_dep,
                    self.row_rank, self.col_rank, self.dep_rank,
                    ParallelMode.PARALLEL_2P5D_ROW,
                    ParallelMode.PARALLEL_2P5D_COL,
                    ParallelMode.PARALLEL_2P5D_DEP,
                    True,
                    self.data_parallel_rank,
                    self.pipeline_parallel_rank,
                    self.pipeline_parallel_size,
                    self.tensor_parallel_size
                )
                return output, bias
            else:
                # add the bias to the matmul output in-place of the caller
                output = Add_Bias_2p5D.apply(
                    output,
                    self.bias,
                    self.hidden_size_per_partition,
                    self.tesseract_dim, self.tesseract_dep,
                    self.row_rank, self.col_rank, self.dep_rank,
                    ParallelMode.PARALLEL_2P5D_ROW,
                    ParallelMode.PARALLEL_2P5D_COL,
                    ParallelMode.PARALLEL_2P5D_DEP,
                    False,
                    self.data_parallel_rank,
                    self.pipeline_parallel_rank,
                    self.pipeline_parallel_size,
                    self.tensor_parallel_size
                )
                return output
        else:
            return output
@LAYERS.register_module
class LayerNorm2p5D(ParallelLayer):
    r"""Layer Normalization for 2.5D parallelism.

    Mean and variance are computed over the (row-partitioned) last
    dimension via all-reduce, then the normalized output is scaled by
    ``gamma`` and shifted by ``beta`` through ``Add_Bias_2p5D``.

    :param normalized_shape: input shape from an expected input
        of size. :math:`[* \times \text{normalized_shape}[0] \times \text{normalized_shape}[1] \times \ldots \times \text{normalized_shape}[-1]]`
        If a single integer is used, it is treated as a singleton list, and this module will
        normalize over the last dimension which is expected to be of that specific size.
    :type normalized_shape: int
    :param eps: a value added to the denominator for numerical stability, defaults to 1e-05
    :type eps: float, optional
    :param dtype: The dtype of parameters, defaults to None
    :type dtype: torch.dtype, optional
    """
    def __init__(self,
                 normalized_shape: int,
                 eps: float = 1e-05,
                 dtype=None
                 ):
        super().__init__()
        # layer norm config
        self.normalized_shape = normalized_shape
        self.variance_epsilon = eps
        # parallel setting
        assert_tesseract_initialization()
        self.row_rank = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_COL)
        self.col_rank = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_ROW)
        self.dep_rank = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_DEP)
        self.tesseract_dim, self.tesseract_dep = get_tesseract_dim_dep_from_env()
        # partitioning dimension: each rank holds 1/q of the feature dim
        self.partitioned_partition = divide(
            normalized_shape, self.tesseract_dim)  # *
        # create parameters
        factory_kwargs = {'device': get_current_device(), 'dtype': dtype}
        if self.row_rank == 0:
            # row rank 0 holds the authoritative gamma/beta slices
            self.gamma = Parameter(torch.ones(
                self.partitioned_partition,
                **factory_kwargs))
            self.beta = Parameter(torch.zeros(
                self.partitioned_partition,
                **factory_kwargs))
        else:
            # other ranks hold scalar placeholders (broadcast happens in
            # Add_Bias_2p5D during forward)
            self.gamma = Parameter(torch.tensor(
                1.0,
                requires_grad=True,
                **factory_kwargs))
            # NOTE(review): beta placeholder is initialized to 1.0 while the
            # row_rank-0 beta uses zeros -- looks inconsistent; confirm the
            # placeholder value is never read before broadcast.
            self.beta = Parameter(torch.tensor(
                1.0,
                requires_grad=True,
                **factory_kwargs))
        self._set_tensor_parallel_attribute()
    def _set_tensor_parallel_attribute(self):
        # mark gamma/beta as tensor-parallel partitioned
        set_tensor_parallel_attribute(self.gamma)
        set_tensor_parallel_attribute(self.beta)
    def forward(self, x: Tensor) -> Tensor:
        with torch.no_grad():
            # mean over the full feature dim via all-reduce of partial sums
            E_x = torch.sum(x, dim=-1, keepdim=True)  # [b/q, s, 1]
            torch.distributed.all_reduce(
                E_x, group=gpc.get_group(ParallelMode.PARALLEL_2P5D_ROW))
            E_x /= self.normalized_shape
            # Var_x in the block below is the sum of input^2
            Var_x = torch.sum(x * x, dim=-1, keepdim=True)  # [b/q, s, 1]
            torch.distributed.all_reduce(
                Var_x, group=gpc.get_group(ParallelMode.PARALLEL_2P5D_ROW))
            Var_x /= self.normalized_shape
            Var_x = Var_x - E_x * E_x  # variance of x [b/q, s, 1]
            # this time 1/sqrt(Var_x + epsilon)
            Var_x = 1.0 / torch.sqrt(Var_x + self.variance_epsilon)
        output = _LayerNorm_2p5D.apply(x, E_x, Var_x, self.normalized_shape,
                                       ParallelMode.PARALLEL_2P5D_ROW,
                                       ParallelMode.PARALLEL_2P5D_COL,
                                       ParallelMode.PARALLEL_2P5D_DEP)
        # broadcast beta/gamma slices to all row ranks (skip=True means no
        # addition is performed, the broadcast values are returned)
        bias = Add_Bias_2p5D.apply(
            None, self.beta, self.partitioned_partition,
            self.tesseract_dim, self.tesseract_dep,
            self.row_rank, self.col_rank, self.dep_rank,
            ParallelMode.PARALLEL_2P5D_ROW,
            ParallelMode.PARALLEL_2P5D_COL,
            ParallelMode.PARALLEL_2P5D_DEP,
            True,
            self.data_parallel_rank,
            self.pipeline_parallel_rank,
            self.pipeline_parallel_size,
            self.tensor_parallel_size
        )
        scale = Add_Bias_2p5D.apply(
            None, self.gamma, self.partitioned_partition,
            self.tesseract_dim, self.tesseract_dep,
            self.row_rank, self.col_rank, self.dep_rank,
            ParallelMode.PARALLEL_2P5D_ROW,
            ParallelMode.PARALLEL_2P5D_COL,
            ParallelMode.PARALLEL_2P5D_DEP,
            True,
            self.data_parallel_rank,
            self.pipeline_parallel_rank,
            self.pipeline_parallel_size,
            self.tensor_parallel_size
        )
        # output = bias + scale * normalized
        output = torch.addcmul(bias, scale, output)
        return output
from ._operation import Matmul_ABT_3D, Matmul_ATB_3D, Matmul_AB_3D, Mul_3D, Sum_3D, Add_3D, Reduce_3D
from ._vit import ViTHead3D, ViTMLP3D, ViTPatchEmbedding3D, ViTSelfAttention3D
from .layers import Linear3D, LayerNorm3D
__all__ = [
'Matmul_ABT_3D', 'Matmul_ATB_3D', 'Matmul_AB_3D', 'Mul_3D', 'Sum_3D', 'Add_3D', 'Reduce_3D',
'ViTHead3D', 'ViTMLP3D', 'ViTPatchEmbedding3D', 'ViTSelfAttention3D',
'Linear3D', 'LayerNorm3D'
]
#!/usr/bin/env python
# -*- encoding: utf-8 -*-
from typing import Any, Tuple
import torch
import torch.distributed as dist
from colossalai.communication import all_gather, reduce_scatter, scatter
from colossalai.context.parallel_mode import ParallelMode
from colossalai.core import global_context as gpc
from colossalai.utils import empty_cache, get_current_device
from torch import Tensor
class Matmul_AB_3D(torch.autograd.Function):
    """Matrix multiplication for :math:`C = AB` in 3D parallelism.

    Forward all-gathers A along ``input_dim`` (input group) and B along
    ``weight_dim`` (weight group), multiplies, then reduce-scatters the
    result along ``output_dim`` (output group).
    """
    @staticmethod
    def forward(ctx: Any,
                A: Tensor,
                B: Tensor,
                depth: int,
                input_parallel_mode: ParallelMode,
                weight_parallel_mode: ParallelMode,
                output_parallel_mode: ParallelMode,
                input_dim: int = 0,
                weight_dim: int = -1,
                output_dim: int = 0) -> Tensor:
        # A: [m/q^2, n, k/q]
        # B: [k/q, h/q^2]
        # C: [m/q^2, n, h/q]
        empty_cache()
        ctx.save_for_backward(A, B)
        assert A.shape[-1] == B.shape[0], \
            'Invalid shapes: A={}, B={}.'.format(A.shape, B.shape)
        A_temp = all_gather(A, input_dim, input_parallel_mode)
        B_temp = all_gather(B, weight_dim, weight_parallel_mode)
        C = torch.matmul(A_temp, B_temp)
        out = reduce_scatter(C, output_dim, output_parallel_mode)
        # remember the group/dim assignment so backward can permute it
        ctx.depth = depth
        ctx.A_group_parallel_mode = input_parallel_mode
        ctx.B_group_parallel_mode = weight_parallel_mode
        ctx.C_group_parallel_mode = output_parallel_mode
        ctx.A_dim = input_dim
        ctx.B_dim = weight_dim
        ctx.C_dim = output_dim
        return out
    @staticmethod
    def backward(ctx: Any, output_grad: Tensor) -> Tuple[Tensor, ...]:
        A, B = ctx.saved_tensors
        with torch.no_grad():
            # dA = dC * B^T, with the C/B/A groups and dims permuted
            A_grad = Matmul_ABT_3D.apply(output_grad, B, ctx.depth,
                                         ctx.C_group_parallel_mode,
                                         ctx.B_group_parallel_mode,
                                         ctx.A_group_parallel_mode, ctx.C_dim,
                                         ctx.B_dim, ctx.A_dim)
            # dB = A^T * dC
            B_grad = Matmul_ATB_3D.apply(A, output_grad, ctx.depth,
                                         ctx.A_group_parallel_mode,
                                         ctx.C_group_parallel_mode,
                                         ctx.B_group_parallel_mode, ctx.A_dim,
                                         ctx.C_dim, ctx.B_dim)
        return A_grad, B_grad, None, None, None, None, None, None, None
class Matmul_ABT_3D(torch.autograd.Function):
    """Matrix multiplication for :math:`C = AB^T` in 3D parallelism.

    Same gather/compute/reduce-scatter structure as ``Matmul_AB_3D`` but
    multiplies by the transpose of the gathered B.
    """
    @staticmethod
    def forward(ctx: Any,
                A: Tensor,
                B: Tensor,
                depth: int,
                input_parallel_mode: ParallelMode,
                weight_parallel_mode: ParallelMode,
                output_parallel_mode: ParallelMode,
                input_dim: int = 0,
                weight_dim: int = -1,
                output_dim: int = 0) -> Tensor:
        # A: [m/q^2, n, h/q]
        # B: [k/q, h/q^2]
        # C: [m/q^2, n, k/q]
        empty_cache()
        ctx.save_for_backward(A, B)
        A_temp = all_gather(A, input_dim, input_parallel_mode)
        B_temp = all_gather(B, weight_dim, weight_parallel_mode)
        C = torch.matmul(A_temp, B_temp.transpose(0, 1))
        out = reduce_scatter(C, output_dim, output_parallel_mode)
        # remember the group/dim assignment so backward can permute it
        ctx.depth = depth
        ctx.A_group_parallel_mode = input_parallel_mode
        ctx.B_group_parallel_mode = weight_parallel_mode
        ctx.C_group_parallel_mode = output_parallel_mode
        ctx.A_dim = input_dim
        ctx.B_dim = weight_dim
        ctx.C_dim = output_dim
        return out
    @staticmethod
    def backward(ctx: Any, output_grad: Tensor) -> Tuple[Tensor, ...]:
        A, B = ctx.saved_tensors
        with torch.no_grad():
            # dA = dC * B
            A_grad = Matmul_AB_3D.apply(output_grad, B, ctx.depth,
                                        ctx.C_group_parallel_mode,
                                        ctx.B_group_parallel_mode,
                                        ctx.A_group_parallel_mode, ctx.C_dim,
                                        ctx.B_dim, ctx.A_dim)
            # dB = dC^T * A
            B_grad = Matmul_ATB_3D.apply(output_grad, A, ctx.depth,
                                         ctx.C_group_parallel_mode,
                                         ctx.A_group_parallel_mode,
                                         ctx.B_group_parallel_mode, ctx.C_dim,
                                         ctx.A_dim, ctx.B_dim)
        return A_grad, B_grad, None, None, None, None, None, None, None
class Matmul_ATB_3D(torch.autograd.Function):
    """Matrix multiplication for :math:`C = A^TB` in 3D parallelism.

    Gathers both operands, flattens their leading dimensions, multiplies
    A's transpose by B, then reduce-scatters the result.
    """
    @staticmethod
    def forward(ctx: Any,
                A: Tensor,
                B: Tensor,
                depth: int,
                input_parallel_mode: ParallelMode,
                weight_parallel_mode: ParallelMode,
                output_parallel_mode: ParallelMode,
                input_dim: int = 0,
                weight_dim: int = 0,
                output_dim: int = -1) -> Tensor:
        # A: [m/q^2, n, k/q]
        # B: [m/q^2, n, h/q]
        # C: [k/q, h/q^2]
        empty_cache()
        ctx.save_for_backward(A, B)
        A_temp = all_gather(A, input_dim, input_parallel_mode)
        # collapse the batch/sequence dims so this is a plain 2D matmul
        A_temp = A_temp.reshape(-1, A.shape[-1])
        B_temp = all_gather(B, weight_dim, weight_parallel_mode)
        B_temp = B_temp.reshape(-1, B.shape[-1])
        C = torch.matmul(A_temp.transpose(0, 1), B_temp)
        out = reduce_scatter(C, output_dim, output_parallel_mode)
        # remember the group/dim assignment so backward can permute it
        ctx.depth = depth
        ctx.A_group_parallel_mode = input_parallel_mode
        ctx.B_group_parallel_mode = weight_parallel_mode
        ctx.C_group_parallel_mode = output_parallel_mode
        ctx.A_dim = input_dim
        ctx.B_dim = weight_dim
        ctx.C_dim = output_dim
        return out
    @staticmethod
    def backward(ctx: Any, output_grad: Tensor) -> Tuple[Tensor, ...]:
        A, B = ctx.saved_tensors
        with torch.no_grad():
            # dA = B * dC^T
            A_grad = Matmul_ABT_3D.apply(B, output_grad, ctx.depth,
                                         ctx.B_group_parallel_mode,
                                         ctx.C_group_parallel_mode,
                                         ctx.A_group_parallel_mode, ctx.B_dim,
                                         ctx.C_dim, ctx.A_dim)
            # dB = A * dC
            B_grad = Matmul_AB_3D.apply(A, output_grad, ctx.depth,
                                        ctx.A_group_parallel_mode,
                                        ctx.C_group_parallel_mode,
                                        ctx.B_group_parallel_mode, ctx.A_dim,
                                        ctx.C_dim, ctx.B_dim)
        return A_grad, B_grad, None, None, None, None, None, None, None
class Add_3D(torch.autograd.Function):
    """Matrix add bias: :math:`C = A + b` in 3D parallelism.

    The bias shard is broadcast from one source rank of the input group,
    all-gathered along the weight group to the full [h/q] slice, and added.
    """
    @staticmethod
    def forward(ctx: Any, input_: Tensor, bias: Tensor, depth: int,
                input_parallel_mode: ParallelMode,
                weight_parallel_mode: ParallelMode,
                output_parallel_mode: ParallelMode) -> Tensor:
        # input: [m/q^2, n, h/q]
        # bias: [h/q^2]
        # source rank: position in the input group given by this rank's
        # local rank in the output group
        ranks_in_group = gpc.get_ranks_in_group(input_parallel_mode)
        src_rank = ranks_in_group[gpc.get_local_rank(output_parallel_mode)]
        # clone so the broadcast does not overwrite the local parameter
        bias_temp = bias.clone()
        dist.broadcast(bias_temp,
                       src=src_rank,
                       group=gpc.get_group(input_parallel_mode))
        # [h/q]
        bias_temp = all_gather(bias_temp, -1, weight_parallel_mode)
        out = input_ + bias_temp
        ctx.depth = depth
        ctx.src_rank = src_rank
        ctx.A_group_parallel_mode = input_parallel_mode
        ctx.B_group_parallel_mode = weight_parallel_mode
        ctx.C_group_parallel_mode = output_parallel_mode
        return out
    @staticmethod
    def backward(ctx: Any, output_grad: Tensor) -> Tuple[Tensor, ...]:
        # output_grad: [m/q^2, n, h/q]
        with torch.no_grad():
            # sum over every dim except the last -> [h/q]
            grad = torch.sum(output_grad,
                             dim=tuple(range(len(output_grad.shape))[:-1]))
            bias_grad = reduce_scatter(grad, -1, ctx.B_group_parallel_mode)
            # accumulate the bias gradient back onto the forward source rank
            dist.reduce(bias_grad,
                        dst=ctx.src_rank,
                        group=gpc.get_group(ctx.A_group_parallel_mode))
            # non-source ranks received only a partial result; drop it
            if gpc.get_local_rank(
                    ctx.A_group_parallel_mode) != gpc.get_local_rank(
                        ctx.C_group_parallel_mode):
                bias_grad = None
        return output_grad, bias_grad, None, None, None, None
class Mul_3D(torch.autograd.Function):
    """Matrix multiplication for :math:`C = A * b` (element-wise scale)
    in 3D parallelism.

    The scale shard is broadcast from one source rank of the input group,
    all-gathered along the weight group to the full [h/q] slice, and
    multiplied element-wise with the input.
    """
    @staticmethod
    def forward(ctx: Any, input_: Tensor, bias: Tensor, depth: int,
                input_parallel_mode: ParallelMode,
                weight_parallel_mode: ParallelMode,
                output_parallel_mode: ParallelMode) -> Tensor:
        # input: [m/q^2, n, h/q]
        # bias: [h/q^2]
        # source rank: position in the input group given by this rank's
        # local rank in the output group
        ranks_in_group = gpc.get_ranks_in_group(input_parallel_mode)
        src_rank = ranks_in_group[gpc.get_local_rank(output_parallel_mode)]
        # [h/q^2]
        bias_temp = bias.clone()
        dist.broadcast(bias_temp,
                       src=src_rank,
                       group=gpc.get_group(input_parallel_mode))
        # [h/q]
        bias_temp = all_gather(bias_temp, -1, weight_parallel_mode)
        empty_cache()
        ctx.save_for_backward(input_, bias_temp)
        out = torch.mul(input_, bias_temp)
        ctx.depth = depth
        ctx.src_rank = src_rank
        ctx.A_group_parallel_mode = input_parallel_mode
        ctx.B_group_parallel_mode = weight_parallel_mode
        ctx.C_group_parallel_mode = output_parallel_mode
        return out
    @staticmethod
    def backward(ctx: Any, output_grad: Tensor) -> Tuple[Tensor, ...]:
        # output_grad: [m/q^2, n, h/q]
        with torch.no_grad():
            input_, bias = ctx.saved_tensors
            # [m/q^2, n, h/q] -- grad wrt input is output_grad * bias
            input_grad = torch.mul(output_grad, bias)
            # grad wrt bias: output_grad * input, summed over all dims
            # except the last -> [h/q]
            grad = torch.mul(output_grad, input_)
            grad = torch.sum(grad,
                             dim=tuple(range(len(output_grad.shape))[:-1]))
            bias_grad = reduce_scatter(grad, -1, ctx.B_group_parallel_mode)
            # accumulate the bias gradient back onto the forward source rank
            dist.reduce(bias_grad,
                        dst=ctx.src_rank,
                        group=gpc.get_group(ctx.A_group_parallel_mode))
            # non-source ranks received only a partial result; drop it
            if gpc.get_local_rank(
                    ctx.A_group_parallel_mode) != gpc.get_local_rank(
                        ctx.C_group_parallel_mode):
                bias_grad = None
        return input_grad, bias_grad, None, None, None, None
class Sum_3D(torch.autograd.Function):
    """Compute the sum of input tensors

    Sums the local partition along ``dim`` and all-reduces the partial sums
    across ``parallel_mode``, so every rank in the group holds the full sum.
    """
    @staticmethod
    def forward(ctx: Any,
                input_: Tensor,
                dim: int,
                depth: int,
                parallel_mode: ParallelMode,
                keepdim: bool = False) -> Tensor:
        # input: [m/q^2, n, h/q]
        out = torch.sum(input_, dim=dim, keepdim=keepdim)
        # combine partial sums from every rank in the group
        dist.all_reduce(out, group=gpc.get_group(parallel_mode))
        ctx.input_shape = input_.shape
        ctx.depth = depth
        ctx.group = parallel_mode
        ctx.dim = dim
        return out
    @staticmethod
    def backward(ctx: Any, output_grad: Tensor) -> Tuple[Tensor, ...]:
        with torch.no_grad():
            output_grad = output_grad.contiguous()
            # NOTE(review): this accumulates incoming grads across the group
            # before expanding them -- confirm this matches how the forward
            # output is consumed downstream
            dist.all_reduce(output_grad, group=gpc.get_group(ctx.group))
            # restore the reduced dim when keepdim was False
            if len(output_grad.shape) < len(ctx.input_shape):
                output_grad = torch.unsqueeze(output_grad, ctx.dim)
            # broadcast the grad along the summed dim back to input shape
            dims = [1 for _ in range(len(output_grad.shape))]
            dims[ctx.dim] = ctx.input_shape[ctx.dim]
            input_grad = output_grad.repeat(tuple(dims))
        return input_grad, None, None, None, None, None
class Reduce_3D(torch.autograd.Function):
    """All-reduce the input tensor over a parallel group.

    After the forward pass every rank in ``parallel_mode`` holds the same
    summed tensor; the gradient passes through unchanged.
    """
    @staticmethod
    def forward(ctx: Any, input_: Tensor, depth: int,
                parallel_mode: ParallelMode) -> Tensor:
        # in-place all-reduce, then hand autograd a fresh copy
        group = gpc.get_group(parallel_mode)
        dist.all_reduce(input_, group=group)
        out = input_.clone()
        return out
    @staticmethod
    def backward(ctx: Any, output_grad: Tensor) -> Tuple[Tensor, ...]:
        # one grad slot per forward argument (input_, depth, parallel_mode)
        return (output_grad, None, None)
class Slice_3D(torch.autograd.Function):
    """Slice input tensor

    Keeps this rank's chunk of ``input_`` along ``dim``; the backward pass
    all-gathers the gradient chunks and reshapes them back to the original
    input shape.
    """
    @staticmethod
    def forward(ctx: Any, input_: Tensor, dim: int, depth: int,
                parallel_mode: ParallelMode) -> Tensor:
        # take the chunk owned by this rank's position in the group
        rank = gpc.get_local_rank(parallel_mode)
        out = torch.chunk(input_, depth, dim=dim)[rank].contiguous()
        # stash what the backward needs to reassemble the full gradient
        ctx.depth = depth
        ctx.parallel_mode = parallel_mode
        ctx.dim = dim
        ctx.input_shape = input_.shape
        return out
    @staticmethod
    def backward(ctx: Any, output_grad: Tensor) -> Tuple[Tensor, ...]:
        with torch.no_grad():
            input_grad = all_gather(output_grad, ctx.dim, ctx.parallel_mode)
            # FIX: Tensor.reshape is out-of-place; the result was previously
            # discarded, so the gathered grad kept its un-reshaped form
            input_grad = input_grad.reshape(ctx.input_shape)
        return input_grad, None, None, None
#!/usr/bin/env python
# -*- encoding: utf-8 -*-
import os
from colossalai.constants import DEPTH_3D
from colossalai.context.parallel_mode import ParallelMode
from colossalai.core import global_context as gpc
from torch import Tensor
def get_depth_from_env() -> int:
    """Read the 3D tensor-parallel depth from the environment.

    :return: the positive depth configured by the process group initializer
    :rtype: int
    :raises EnvironmentError: if the depth variable is not set
    """
    try:
        depth = int(os.environ[DEPTH_3D])
        assert depth > 0, 'DEPTH must be greater than zero'
        return depth
    except KeyError as e:
        # chain the original KeyError so the root cause stays visible
        raise EnvironmentError(
            'DEPTH is not found in the current environment, '
            'please make sure that you have used the correct process group initializer'
        ) from e
def get_last_group(a, b):
    """Given two of the three 3D parallel modes, return the remaining one."""
    mapping = {
        ParallelMode.PARALLEL_3D_INPUT: 'A',
        ParallelMode.PARALLEL_3D_WEIGHT: 'B',
        ParallelMode.PARALLEL_3D_OUTPUT: 'C',
    }
    # the single label not covered by a and b
    leftover = {'A', 'B', 'C'} - {mapping[a], mapping[b]}
    if leftover == {'A'}:
        return ParallelMode.PARALLEL_3D_INPUT
    elif leftover == {'B'}:
        return ParallelMode.PARALLEL_3D_WEIGHT
    elif leftover == {'C'}:
        return ParallelMode.PARALLEL_3D_OUTPUT
def dbg_check_shape(tensor: Tensor, shape: tuple):
    """Debug helper: on global rank 0, print the tensor's shape and assert
    it matches the expected one. No-op on every other rank."""
    if gpc.get_global_rank() == 0:
        print(tensor.shape)
        assert tensor.shape == shape, \
            '{} does not match {}'.format(tensor.shape, shape)
import math
from typing import Tuple
import torch
import torch.distributed as dist
from colossalai.context import ParallelMode, seed
from colossalai.core import global_context as gpc
from colossalai.registry import LAYERS
from colossalai.utils import checkpoint, get_current_device
from torch import Tensor, dtype, nn
from .._common_utils import ACT2FN, divide, set_tensor_parallel_attribute
from ..vanilla_vision_transformer.layers import to_2tuple
from ._utils import get_depth_from_env
from .layers import Linear3D
@LAYERS.register_module
class ViTPatchEmbedding3D(nn.Module):
    """ 3D Image to Patch Embedding

    The embedding dimension is partitioned over the 3D depth: the conv
    projection produces ``embed_size / depth`` channels per rank, and
    ``forward`` additionally splits the batch across the weight and input
    parallel groups.

    :param img_size: image size
    :type img_size: int
    :param patch_size: patch size
    :type patch_size: int
    :param in_chans: number of channels of input image
    :type in_chans: int
    :param embed_size: dimension of embedding
    :type embed_size: int
    :param drop_prob: dropout probability
    :type drop_prob: float
    :param flatten: whether to flatten output tensor, defaults to True
    :type flatten: bool, optional
    """
    def __init__(self,
                 img_size: int,
                 patch_size: int,
                 in_chans: int,
                 embed_size: int,
                 drop_prob: float,
                 flatten: bool = True):
        super().__init__()
        self.depth = get_depth_from_env()
        self.input_parallel_mode = ParallelMode.PARALLEL_3D_INPUT
        self.weight_parallel_mode = ParallelMode.PARALLEL_3D_WEIGHT
        self.output_parallel_mode = ParallelMode.PARALLEL_3D_OUTPUT
        img_size = to_2tuple(img_size)
        patch_size = to_2tuple(patch_size)
        self.img_size = img_size
        self.patch_size = patch_size
        self.grid_size = (img_size[0] // patch_size[0],
                          img_size[1] // patch_size[1])
        self.embed_size = embed_size
        # each rank holds a 1/depth slice of the embedding dimension
        self.embed_size_per_partition = divide(self.embed_size, self.depth)
        self.num_patches = self.grid_size[0] * self.grid_size[1]
        self.flatten = flatten
        # initialize parameters under the tensor-parallel RNG context
        # (project-managed seed)
        with seed(ParallelMode.TENSOR):
            self.proj = nn.Conv2d(in_chans,
                                  self.embed_size_per_partition,
                                  kernel_size=patch_size,
                                  stride=patch_size)
            self.cls_token = nn.Parameter(
                torch.zeros(1, 1, self.embed_size_per_partition))
            self.pos_embed = nn.Parameter(
                torch.zeros(1, self.num_patches + 1,
                            self.embed_size_per_partition))
            self.pos_drop = nn.Dropout(drop_prob)
        # make proj copies identical across the weight and input groups,
        # and keep all parameter grads synchronized across those groups
        self._sync_parameters()
        self.proj.weight.register_hook(self._sync_grad_hook)
        self.proj.bias.register_hook(self._sync_grad_hook)
        self.cls_token.register_hook(self._sync_grad_hook)
        self.pos_embed.register_hook(self._sync_grad_hook)
        self._set_tensor_parallel_attribute()
    def _set_tensor_parallel_attribute(self):
        # mark the partitioned parameters as tensor-parallel
        set_tensor_parallel_attribute(self.proj.weight)
        set_tensor_parallel_attribute(self.proj.bias)
        set_tensor_parallel_attribute(self.cls_token)
        set_tensor_parallel_attribute(self.pos_embed)
    def groups_for_next_layer(self) -> Tuple[ParallelMode, ParallelMode]:
        # parallel groups the following layer should build its matmuls on
        return self.input_parallel_mode, self.weight_parallel_mode
    def _sync_parameters(self):
        # move to the local device, then broadcast the conv parameters from
        # rank 0 of the weight group and rank 0 of the input group so every
        # rank in those groups starts from identical weights
        self.to(get_current_device())
        weight_src_rank = gpc.get_ranks_in_group(self.weight_parallel_mode)[0]
        dist.broadcast(self.proj.weight,
                       src=weight_src_rank,
                       group=gpc.get_group(self.weight_parallel_mode))
        dist.broadcast(self.proj.bias,
                       src=weight_src_rank,
                       group=gpc.get_group(self.weight_parallel_mode))
        input_src_rank = gpc.get_ranks_in_group(self.input_parallel_mode)[0]
        dist.broadcast(self.proj.weight,
                       src=input_src_rank,
                       group=gpc.get_group(self.input_parallel_mode))
        dist.broadcast(self.proj.bias,
                       src=input_src_rank,
                       group=gpc.get_group(self.input_parallel_mode))
        set_tensor_parallel_attribute(self.proj.weight)
        set_tensor_parallel_attribute(self.proj.bias)
        set_tensor_parallel_attribute(self.cls_token)
        set_tensor_parallel_attribute(self.pos_embed)
    def _sync_grad_hook(self, grad) -> Tensor:
        # accumulate grads across both groups that share these parameters
        dist.all_reduce(grad, group=gpc.get_group(self.input_parallel_mode))
        dist.all_reduce(grad, group=gpc.get_group(self.weight_parallel_mode))
        return grad
    def forward(self, x: Tensor) -> Tensor:
        B, C, H, W = x.shape
        assert H == self.img_size[0] and W == self.img_size[1], \
            f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
        x = self.proj(x)
        if self.flatten:
            x = x.flatten(2).transpose(1, 2)  # BCHW -> BNC
        # split a partition from embedded states: batch is chunked twice,
        # once per group, yielding [b/q^2, s, h/q] locally
        x = torch.chunk(x, self.depth, dim=0)[gpc.get_local_rank(
            self.weight_parallel_mode)].contiguous()
        x = torch.chunk(x, self.depth, dim=0)[gpc.get_local_rank(
            self.input_parallel_mode)].contiguous()
        # add cls token & pos embedding
        # [b/q^2,s,h/q] --> [b/q^2, 1+s, h/q]
        cls_token = self.cls_token.expand(x.shape[0], -1, -1)
        x = torch.cat((cls_token, x), dim=1)
        with seed(ParallelMode.TENSOR):
            x = self.pos_drop(x + self.pos_embed)
        return x
@LAYERS.register_module
class ViTSelfAttention3D(nn.Module):
    """Self-attention layer for 3D parallel Vision Transformer

    Attention heads are partitioned over the 3D depth; the fused QKV and the
    output projection are 3D-parallel linear layers. The parallel depth and
    groups are read from the environment rather than passed in.

    :param hidden_size: hidden size
    :type hidden_size: int
    :param num_attention_heads: number of attention heads
    :type num_attention_heads: int
    :param attention_probs_dropout_prob: dropout probability for attention layers
    :type attention_probs_dropout_prob: float
    :param hidden_dropout_prob: dropout probability for hidden layers
    :type hidden_dropout_prob: float
    :param dtype: dtype of parameters, defaults to None
    :type dtype: dtype, optional
    :param bias: whether to add bias, defaults to True
    :type bias: bool, optional
    :param checkpoint: whether to use activation checkpointing, defaults to False
    :type checkpoint: bool, optional
    """
    def __init__(self,
                 hidden_size: int,
                 num_attention_heads: int,
                 attention_probs_dropout_prob: float,
                 hidden_dropout_prob: float,
                 dtype: dtype = None,
                 bias: bool = True,
                 checkpoint: bool = False):
        super().__init__()
        self.depth = get_depth_from_env()
        self.input_parallel_mode = ParallelMode.PARALLEL_3D_INPUT
        self.weight_parallel_mode = ParallelMode.PARALLEL_3D_WEIGHT
        self.output_parallel_mode = ParallelMode.PARALLEL_3D_OUTPUT
        self.hidden_size = hidden_size
        # heads are split across the depth; head size uses the full count
        self.num_attention_heads = divide(num_attention_heads, self.depth)
        self.attention_head_size = divide(hidden_size, num_attention_heads)
        self.all_head_size = self.num_attention_heads * self.attention_head_size
        self.checkpoint = checkpoint
        # fused QKV projection: one matmul producing 3 * hidden_size
        self.query_key_value = Linear3D(self.hidden_size,
                                        3 * self.hidden_size,
                                        self.input_parallel_mode,
                                        self.weight_parallel_mode,
                                        dtype=dtype,
                                        bias=bias)
        self.attention_dropout = nn.Dropout(attention_probs_dropout_prob)
        self.dense = Linear3D(self.hidden_size,
                              self.hidden_size,
                              self.output_parallel_mode,
                              self.weight_parallel_mode,
                              dtype=dtype,
                              bias=bias)
        self.dropout = nn.Dropout(hidden_dropout_prob)
        self.softmax = nn.Softmax(dim=-1)
    def groups_for_next_layer(self) -> Tuple[ParallelMode, ParallelMode]:
        # parallel groups the following layer should build its matmuls on
        return self.input_parallel_mode, self.weight_parallel_mode
    def _forward(self, hidden_states: Tensor) -> Tensor:
        # fused QKV projection
        query_key_value = self.query_key_value(hidden_states)
        # [b, s, heads, 3*head] -> [b, heads, s, 3*head], then split Q/K/V
        new_qkv_shape = query_key_value.shape[:-1] + \
            (self.num_attention_heads, 3 * self.attention_head_size)
        query_key_value = query_key_value.view(new_qkv_shape)
        query_key_value = query_key_value.permute((0, 2, 1, 3))
        query_layer, key_layer, value_layer = torch.chunk(query_key_value,
                                                          3,
                                                          dim=-1)
        # scaled dot-product attention
        attention_scores = torch.matmul(query_layer,
                                        key_layer.transpose(-1, -2))
        attention_scores = attention_scores / math.sqrt(
            self.attention_head_size)
        attention_probs = self.softmax(attention_scores)
        # dropout under the tensor-parallel RNG context
        with seed(ParallelMode.TENSOR):
            attention_probs = self.attention_dropout(attention_probs)
        context_layer = torch.matmul(attention_probs, value_layer)
        # merge heads back: [b, heads, s, head] -> [b, s, all_head_size]
        context_layer = context_layer.transpose(1, 2)
        new_context_layer_shape = context_layer.size()[:-2] + (
            self.all_head_size, )
        context_layer = context_layer.reshape(new_context_layer_shape)
        output = self.dense(context_layer)
        with seed(ParallelMode.TENSOR):
            output = self.dropout(output)
        return output
    def _checkpoint_forward(self, hidden_states: Tensor) -> Tensor:
        # recompute activations during backward to save memory
        return checkpoint(self._forward, hidden_states)
    def forward(self, hidden_states: Tensor) -> Tensor:
        if self.checkpoint:
            return self._checkpoint_forward(hidden_states)
        else:
            return self._forward(hidden_states)
@LAYERS.register_module
class ViTMLP3D(nn.Module):
    """Two-layer MLP block for 3D parallel Vision Transformer: a 3D-parallel
    expansion to ``mlp_ratio * hidden_size``, an activation, and a
    3D-parallel projection back to ``hidden_size`` followed by dropout.

    :param hidden_size: hidden size
    :type hidden_size: int
    :param mlp_ratio: hidden size of MLP divided by embedding dim
    :type mlp_ratio: int
    :param hidden_dropout_prob: dropout probability for hidden layers
    :type hidden_dropout_prob: float
    :param hidden_act: activation function for hidden layers
    :type hidden_act: str
    :param dtype: dtype of parameters, defaults to None
    :type dtype: dtype, optional
    :param bias: whether to add bias, defaults to True
    :type bias: bool, optional
    :param checkpoint: whether to use activation checkpointing, defaults to False
    :type checkpoint: bool, optional
    """
    def __init__(self,
                 hidden_size: int,
                 mlp_ratio: int,
                 hidden_dropout_prob: float,
                 hidden_act: str = 'gelu',
                 dtype: dtype = None,
                 bias: bool = True,
                 checkpoint: bool = False):
        super().__init__()
        self.depth = get_depth_from_env()
        self.input_parallel_mode = ParallelMode.PARALLEL_3D_INPUT
        self.weight_parallel_mode = ParallelMode.PARALLEL_3D_WEIGHT
        self.output_parallel_mode = ParallelMode.PARALLEL_3D_OUTPUT
        self.hidden_size = hidden_size
        self.mlp_ratio = mlp_ratio
        self.checkpoint = checkpoint
        # expansion layer
        self.dense_1 = Linear3D(self.hidden_size,
                                self.mlp_ratio * self.hidden_size,
                                self.input_parallel_mode,
                                self.weight_parallel_mode,
                                dtype=dtype,
                                bias=bias)
        # activation looked up by name from the shared registry
        self.activation_func = ACT2FN[hidden_act]
        # projection layer
        self.dense_2 = Linear3D(self.mlp_ratio * self.hidden_size,
                                self.hidden_size,
                                self.output_parallel_mode,
                                self.weight_parallel_mode,
                                dtype=dtype,
                                bias=bias)
        self.dropout = nn.Dropout(hidden_dropout_prob)
    def groups_for_next_layer(self) -> Tuple[ParallelMode, ParallelMode]:
        # parallel groups the following layer should build its matmuls on
        return self.input_parallel_mode, self.weight_parallel_mode
    def _forward(self, hidden_states: Tensor) -> Tensor:
        intermediate_output = self.dense_1(hidden_states)
        intermediate_output = self.activation_func(intermediate_output)
        output = self.dense_2(intermediate_output)
        # dropout under the tensor-parallel RNG context
        with seed(ParallelMode.TENSOR):
            output = self.dropout(output)
        return output
    def _checkpoint_forward(self, hidden_states: Tensor) -> Tensor:
        # recompute activations during backward to save memory
        return checkpoint(self._forward, hidden_states)
    def forward(self, hidden_states: Tensor) -> Tensor:
        if self.checkpoint:
            return self._checkpoint_forward(hidden_states)
        else:
            return self._forward(hidden_states)
@LAYERS.register_module
class ViTHead3D(nn.Module):
    """Output layer for 3D parallel Vision Transformer

    Projects the class token to class logits with a 3D-parallel linear
    layer. ``num_classes`` is padded up to a multiple of ``depth**2`` so the
    output can be evenly partitioned; the padding columns are trimmed from
    the local output in ``forward``.

    :param in_features: size of input tensor
    :type in_features: int
    :param num_classes: number of classes
    :type num_classes: int
    :param dtype: dtype of parameters, defaults to None
    :type dtype: dtype, optional
    :param bias: whether to add bias, defaults to True
    :type bias: bool, optional
    """
    def __init__(self,
                 in_features: int,
                 num_classes: int,
                 dtype: dtype = None,
                 bias: bool = True):
        super().__init__()
        self.depth = get_depth_from_env()
        self.input_parallel_mode = ParallelMode.PARALLEL_3D_INPUT
        self.weight_parallel_mode = ParallelMode.PARALLEL_3D_WEIGHT
        self.output_parallel_mode = ParallelMode.PARALLEL_3D_OUTPUT
        self.in_features = in_features
        self.num_classes = num_classes
        # round class count up to the next multiple of depth^2 so the
        # Linear3D output dim partitions evenly
        out_features = math.ceil(self.num_classes /
                                 (self.depth**2)) * (self.depth**2)
        # genuine (unpadded) classes held by this rank; requires
        # num_classes to be divisible by depth
        self.num_classes_per_partition = divide(self.num_classes, self.depth)
        self.linear = Linear3D(self.in_features,
                               out_features,
                               self.input_parallel_mode,
                               self.weight_parallel_mode,
                               dtype=dtype,
                               bias=bias)
    def groups_for_next_layer(self) -> Tuple[ParallelMode, ParallelMode]:
        # delegate to the linear layer's group arrangement
        return self.linear.groups_for_next_layer()
    def forward(self, x: Tensor) -> Tensor:
        # [b/q^2, s, h/q] --> [b/q^2, h/q]
        # keep only the class token
        x = x[:, 0]
        # [b/q^2, h/q] --> [b/q^2, c/q]
        x = self.linear(x)
        # drop the padding columns introduced by the ceil() above
        return x[:, :self.num_classes_per_partition]
    def extra_repr(self):
        return 'in_features={}, num_classes={}'.format(self.in_features,
                                                       self.num_classes)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment