Unverified Commit 01a80cd8 authored by アマデウス's avatar アマデウス Committed by GitHub
Browse files

Hotfix/Colossalai layers (#92)



* optimized 1d layer apis; reorganized nn.layer modules; fixed tests

* fixed 2.5d runtime issue

* reworked split batch, now called in trainer.schedule.load_batch
Co-authored-by: default avatarBoxiangW <45734921+BoxiangW@users.noreply.github.com>
parent 0fedef4f
......@@ -3,10 +3,10 @@
import math
import numbers
from contextlib import nullcontext
from typing import Callable, Tuple
import torch
import torch.distributed as dist
import torch.nn.functional as F
from colossalai.communication import broadcast
from colossalai.context import ParallelMode, seed
......@@ -14,13 +14,122 @@ from colossalai.core import global_context as gpc
from colossalai.nn import init as init
from colossalai.registry import LAYERS
from colossalai.utils import get_current_device
from torch import Tensor
from torch import Tensor, dtype
from torch.nn.parameter import Parameter
from .._common_utils import divide, set_tensor_parallel_attribute_by_partition
from ..base_layer import ParallelLayer
from ..utils import divide, set_tensor_parallel_attribute_by_partition
from ._operation import FusedLayerNormAffineFunction1D
from ._utils import (gather_forward_split_backward, reduce_grad, reduce_input, split_forward_gather_backward)
from ._utils import (gather_forward_split_backward, get_parallel_input, reduce_grad, reduce_input, set_parallel_input,
split_forward_gather_backward)
@LAYERS.register_module
class Linear1D(torch.nn.Module):
def __init__(self,
in_features: int,
out_features: int,
bias: bool = True,
dtype: torch.dtype = None,
gather_output: bool = False,
skip_bias_add: bool = False,
weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)),
bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1)):
super().__init__()
parallel_input = get_parallel_input()
if not parallel_input:
self.layer = Linear1D_Col(in_features,
out_features,
bias=bias,
dtype=dtype,
gather_output=gather_output,
skip_bias_add=skip_bias_add,
weight_initializer=weight_initializer,
bias_initializer=bias_initializer)
else:
self.layer = Linear1D_Row(in_features,
out_features,
bias=bias,
dtype=dtype,
parallel_input=parallel_input,
skip_bias_add=skip_bias_add,
weight_initializer=weight_initializer,
bias_initializer=bias_initializer)
@property
def weight(self):
return self.layer.weight
@property
def bias(self):
return self.layer.bias
def forward(self, input_: Tensor) -> Tensor:
return self.layer(input_)
@LAYERS.register_module
class Classifier1D(ParallelLayer):
"""RowLinear with given weight"""
def __init__(self,
in_features: int,
num_classes: int,
weight: Parameter = None,
bias: bool = True,
dtype: dtype = None,
weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)),
bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1)):
super().__init__()
self.in_features = in_features
self.num_classes = num_classes
self.parallel_input = get_parallel_input()
# Divide the weight matrix along the last dimension.
self.input_size_per_partition = divide(in_features, gpc.tensor_parallel_size)
# Parameters.
# Initialize weight.
factory_kwargs = {'device': get_current_device(), 'dtype': dtype}
if weight is not None:
self.weight = weight
self.has_weight = False
else:
self.weight = Parameter(torch.empty(self.num_classes, self.input_size_per_partition, **factory_kwargs))
self.has_weight = True
if bias:
self.bias = Parameter(torch.empty(self.num_classes, **factory_kwargs))
else:
self.bias = None
with seed(ParallelMode.TENSOR):
self.reset_parameters(weight_initializer, bias_initializer)
self._set_tensor_parallel_attributes()
set_parallel_input(False)
def reset_parameters(self, weight_initializer, bias_initializer) -> None:
fan_in, fan_out = self.in_features, self.num_classes
if self.has_weight:
weight_initializer(self.weight, fan_in=fan_in, fan_out=fan_out)
if self.bias is not None:
bias_initializer(self.bias, fan_in=fan_in)
broadcast(self.bias, gpc.get_ranks_in_group(ParallelMode.PARALLEL_1D)[0], ParallelMode.PARALLEL_1D)
def _set_tensor_parallel_attributes(self):
if self.has_weight:
num_partition = gpc.get_world_size(ParallelMode.TENSOR)
set_tensor_parallel_attribute_by_partition(self.weight, num_partition)
def forward(self, input_: Tensor) -> Tensor:
# Set up backprop all-reduce.
if self.parallel_input:
input_ = input_
else:
input_ = split_forward_gather_backward(input_, ParallelMode.PARALLEL_1D, dim=-1)
output_parallel = F.linear(input_, self.weight)
output = reduce_input(output_parallel, ParallelMode.PARALLEL_1D)
output = output + self.bias
return output
@LAYERS.register_module
......@@ -77,6 +186,7 @@ class Linear1D_Col(ParallelLayer):
with seed(ParallelMode.TENSOR):
self.reset_parameters(weight_initializer, bias_initializer)
self._set_tensor_parallel_attributes()
set_parallel_input(True)
def reset_parameters(self, weight_initializer, bias_initializer) -> None:
fan_in, fan_out = self.in_features, self.out_features
......@@ -158,6 +268,7 @@ class Linear1D_Row(ParallelLayer):
with seed(ParallelMode.TENSOR):
self.reset_parameters(weight_initializer, bias_initializer)
self._set_tensor_parallel_attributes()
set_parallel_input(False)
def reset_parameters(self, weight_initializer, bias_initializer) -> None:
fan_in, fan_out = self.in_features, self.out_features
......@@ -208,3 +319,68 @@ class MixedFusedLayerNorm1D(torch.nn.Module):
def forward(self, input):
return FusedLayerNormAffineFunction1D.apply(input, self.weight, self.bias, self.normalized_shape, self.eps)
@LAYERS.register_module
class Embedding1D(ParallelLayer):
def __init__(self,
num_embeddings: int,
embedding_dim: int,
padding_idx: int = None,
dtype: dtype = None,
weight_initializer: Callable = init.normal_(),
*args,
**kwargs):
super().__init__()
self.num_embeddings = num_embeddings
self.embed_dim = embedding_dim
embed_dim_per_partition = divide(embedding_dim, gpc.tensor_parallel_size)
self.padding_idx = padding_idx
self.embed_args = args
self.embed_kwargs = kwargs
self.weight = Parameter(
torch.empty((num_embeddings, embed_dim_per_partition), device=get_current_device(), dtype=dtype))
self.reset_parameters(weight_initializer)
self._set_tensor_parallel_attributes()
set_parallel_input(False)
def _set_tensor_parallel_attributes(self):
set_tensor_parallel_attribute_by_partition(self.weight, gpc.tensor_parallel_size)
def reset_parameters(self, weight_initializer) -> None:
with seed(ParallelMode.TENSOR):
fan_in, fan_out = self.num_embeddings, self.embed_dim
weight_initializer(self.weight, fan_in=fan_in, fan_out=fan_out)
self._fill_padding_idx_with_zero()
def _fill_padding_idx_with_zero(self) -> None:
if self.padding_idx is not None:
with torch.no_grad():
self.weight[self.padding_idx].fill_(0)
def forward(self, input_: Tensor) -> Tensor:
output_parallel = F.embedding(input_, self.weight, self.padding_idx, *self.embed_args, **self.embed_kwargs)
output = gather_forward_split_backward(output_parallel, ParallelMode.PARALLEL_1D, dim=-1)
return output
@LAYERS.register_module
class Dropout1D(ParallelLayer):
def __init__(self, p: float = 0.5, inplace: bool = False):
super().__init__()
self.parallel_input = get_parallel_input()
self.p = p
self.inplace = inplace
def forward(self, input_: Tensor) -> Tensor:
cm = nullcontext() if not self.parallel_input else seed(ParallelMode.TENSOR)
with cm:
output = F.dropout(input_, self.p, self.training, self.inplace)
return output
from ._operation import reduce_by_batch_2d, split_batch_2d
from ._operation import reduce_by_batch_2d, split_tensor_2d
from .layers import Classifier2D, Embedding2D, LayerNorm2D, Linear2D, PatchEmbedding2D
__all__ = [
'split_batch_2d', 'reduce_by_batch_2d', 'Linear2D', 'LayerNorm2D', 'Classifier2D', 'PatchEmbedding2D', 'Embedding2D'
'split_tensor_2d', 'reduce_by_batch_2d', 'Linear2D', 'LayerNorm2D', 'Classifier2D', 'PatchEmbedding2D', 'Embedding2D'
]
......@@ -2,7 +2,7 @@ from typing import Any, Optional, Tuple
import torch
import torch.distributed as dist
from colossalai.communication.collective import (all_gather, all_reduce, reduce_scatter)
from colossalai.communication.collective import (all_gather, all_reduce, reduce, reduce_scatter)
from colossalai.context.parallel_mode import ParallelMode
from colossalai.core import global_context as gpc
from colossalai.utils import get_current_device
......@@ -595,7 +595,9 @@ class SplitFirst(torch.autograd.Function):
return grad, None, None
def split_batch_2d(input_: Tensor, dim: int = 0) -> Tensor:
def split_tensor_2d(input_: Tensor, dim: int = 0) -> Tensor:
if input_.size(dim) <= 1:
return input_
return torch.chunk(input_, gpc.get_world_size(ParallelMode.PARALLEL_2D_COL),
dim=dim)[gpc.get_local_rank(ParallelMode.PARALLEL_2D_COL)].contiguous()
......@@ -603,17 +605,28 @@ def split_batch_2d(input_: Tensor, dim: int = 0) -> Tensor:
class reduce_by_batch_2d(torch.autograd.Function):
"""All-reduce the input from the model parallel region."""
@staticmethod
def symbolic(graph, input_):
dist.all_reduce(input_, group=gpc.get_group(ParallelMode.PARALLEL_2D_COL))
return input_
def symbolic(graph, input_, reduce_mean: bool = False):
output = all_reduce(input_, ParallelMode.PARALLEL_2D_COL)
if reduce_mean:
reduce_size = gpc.get_world_size(ParallelMode.PARALLEL_2D_COL)
return output / reduce_size
return output
@staticmethod
@custom_fwd(cast_inputs=torch.float32)
def forward(ctx, input_):
dist.all_reduce(input_, group=gpc.get_group(ParallelMode.PARALLEL_2D_COL))
return input_.clone()
def forward(ctx, input_, reduce_mean: bool = False):
output = all_reduce(input_, ParallelMode.PARALLEL_2D_COL)
ctx.reduce_mean = reduce_mean
if reduce_mean:
reduce_size = gpc.get_world_size(ParallelMode.PARALLEL_2D_COL)
ctx.reduce_size = reduce_size
return output.clone() / reduce_size
return output.clone()
@staticmethod
@custom_bwd
def backward(ctx, grad_output):
return grad_output
def backward(ctx, output_grad):
if ctx.reduce_mean:
return output_grad / ctx.reduce_size, None
else:
return output_grad, None
......@@ -13,9 +13,9 @@ from colossalai.utils import get_current_device
from torch import Tensor, dtype
from torch.nn import Parameter
from .._common_utils import (divide, set_tensor_parallel_attribute_by_partition, to_2tuple)
from ..utils import divide, set_tensor_parallel_attribute_by_partition, to_2tuple
from ..base_layer import ParallelLayer
from ._operation import (Matmul_AB_2D, add_bias_2d, all_gather_weight_2d, classifier_2d, layernorm_2d, split_batch_2d)
from ._operation import Matmul_AB_2D, add_bias_2d, all_gather_weight_2d, classifier_2d, layernorm_2d
from ._utils import assert_summa_initialization, get_summa_dim_from_env
......@@ -257,8 +257,6 @@ class PatchEmbedding2D(ParallelLayer):
assert H == self.img_size[0] and W == self.img_size[1], \
f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
input_ = split_batch_2d(input_)
weight = all_gather_weight_2d.apply(self.weight, 0, self.summa_dim, ParallelMode.PARALLEL_2D_COL)
bias = all_gather_weight_2d.apply(self.bias, 0, self.summa_dim, ParallelMode.PARALLEL_2D_COL)
......@@ -318,8 +316,6 @@ class Embedding2D(ParallelLayer):
self.weight[self.padding_idx].fill_(0)
def forward(self, input_: Tensor) -> Tensor:
input_ = split_batch_2d(input_)
weight = all_gather_weight_2d.apply(self.weight, -1, self.summa_dim, ParallelMode.PARALLEL_2D_COL)
output = F.embedding(input_, weight, self.padding_idx, *self.embed_args, **self.embed_kwargs)
......
from ._operation import reduce_by_batch_2p5d, split_batch_2p5d
from ._operation import reduce_by_batch_2p5d, split_tensor_2p5d
from .layers import Classifier2p5D, Embedding2p5D, LayerNorm2p5D, Linear2p5D, PatchEmbedding2p5D
__all__ = [
'split_batch_2p5d', 'reduce_by_batch_2p5d', 'Linear2p5D', 'LayerNorm2p5D', 'Classifier2p5D', 'PatchEmbedding2p5D',
'split_tensor_2p5d', 'reduce_by_batch_2p5d', 'Linear2p5D', 'LayerNorm2p5D', 'Classifier2p5D', 'PatchEmbedding2p5D',
'Embedding2p5D'
]
......@@ -22,7 +22,7 @@ def get_parallel_rank(parallel_mode: ParallelMode):
return gpc.get_local_rank(parallel_mode)
def split_batch_2p5d(input_: Tensor, dim: int = 0) -> Tensor:
def split_tensor_2p5d(input_: Tensor, dim: int = 0) -> Tensor:
return torch.chunk(input_, gpc.get_world_size(ParallelMode.PARALLEL_2P5D_COL),
dim=dim)[gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_COL)].contiguous()
......@@ -120,30 +120,53 @@ class Matmul_AB_2p5D(torch.autograd.Function):
ctx.save_for_backward(A, B)
A_shape = A.shape
A = A.reshape((-1, A_shape[-1])).contiguous()
A = A.reshape((-1, A_shape[-1]))
B_shape = B.shape
B = B.reshape((-1, B_shape[-1])).contiguous()
B = B.reshape((-1, B_shape[-1]))
C_shape = (A.shape[0], B.shape[-1])
C = torch.zeros(C_shape, dtype=A.dtype, device=get_current_device())
A_list = [torch.empty_like(A) for _ in range(gpc.get_world_size(row_parallel_mode) - 1)]
B_list = [torch.empty_like(B) for _ in range(gpc.get_world_size(col_parallel_mode) - 1)]
A_list.insert(gpc.get_local_rank(row_parallel_mode), A)
B_list.insert(gpc.get_local_rank(col_parallel_mode), B)
op_a = dist.all_gather(A_list, A, group=gpc.get_group(row_parallel_mode), async_op=True)
op_a.wait()
op_b = dist.all_gather(B_list, B, group=gpc.get_group(col_parallel_mode), async_op=True)
for op in [op_a, op_b]:
op.wait()
# use circular buffer to store the communication tensor
# 2 is enough for all cases
A_list = [torch.empty_like(A) for _ in range(2)]
B_list = [torch.empty_like(B) for _ in range(2)]
row_group = gpc.get_group(row_parallel_mode)
col_group = gpc.get_group(col_parallel_mode)
src_a = tesseract_dim * row_rank + tesseract_dim ** 2 * dep_rank + data_parallel_rank * pipeline_parallel_size * tensor_parallel_size + \
pipeline_parallel_rank * tensor_parallel_size
src_b = col_rank + tesseract_dim ** 2 * dep_rank + data_parallel_rank * pipeline_parallel_size * tensor_parallel_size + \
pipeline_parallel_rank * tensor_parallel_size
opa = [None] * 2
opb = [None] * 2
A_list[0].copy_(A)
B_list[0].copy_(B)
opa[0] = dist.broadcast(A_list[0], src=src_a, group=row_group, async_op=True)
opb[0] = dist.broadcast(B_list[0], src=src_b, group=col_group, async_op=True)
cur = 0
for i in range(tesseract_dim):
src_a = i + tesseract_dim * row_rank
src_b = i + tesseract_dim * col_rank
src_a = src_a % tesseract_dim
src_b = src_b % tesseract_dim
A_temp = A_list[src_a]
B_temp = B_list[src_b]
torch.addmm(C, A_temp, B_temp, out=C)
if i != tesseract_dim - 1:
A_list[1 - cur].copy_(A)
opa[1 - cur] = dist.broadcast(A_list[1 - cur], src=src_a + 1, group=row_group, async_op=True)
B_list[1 - cur].copy_(B)
opb[1 - cur] = dist.broadcast(B_list[1 - cur],
src=src_b + tesseract_dim,
group=col_group,
async_op=True)
if opa[cur] is not None:
opa[cur].wait()
if opb[cur] is not None:
opb[cur].wait()
torch.addmm(C, A_list[cur], B_list[cur], out=C)
cur = 1 - cur
src_a += 1
src_b += tesseract_dim
out = C.reshape(out_shape)
if ctx:
......@@ -201,20 +224,55 @@ class Matmul_ABT_2p5D(torch.autograd.Function):
C_shape = (A.shape[0], B.shape[0])
C = torch.empty(C_shape, dtype=A.dtype, device=get_current_device())
# use circular buffer to store the communication tensor
# 2 is enough for all cases
B_list = [torch.empty_like(B) for _ in range(2)]
C_list = [torch.empty_like(C) for _ in range(2)]
row_group = gpc.get_group(row_parallel_mode)
col_group = gpc.get_group(col_parallel_mode)
src_b = col_rank + tesseract_dim ** 2 * dep_rank + data_parallel_rank * pipeline_parallel_size * tensor_parallel_size + \
pipeline_parallel_rank * tensor_parallel_size
src_c = tesseract_dim * row_rank + tesseract_dim ** 2 * dep_rank + data_parallel_rank * pipeline_parallel_size * tensor_parallel_size + \
pipeline_parallel_rank * tensor_parallel_size
opb = [None] * 2
opr = [None] * 2
B_list[0].copy_(B)
opb[0] = dist.broadcast(B_list[0], src=src_b, group=col_group, async_op=True)
cur = 0
for i in range(tesseract_dim):
B_temp = B.clone()
src_b = col_rank + i * tesseract_dim + dep_rank * (
tesseract_dim ** 2) + data_parallel_rank * pipeline_parallel_size * tensor_parallel_size + \
pipeline_parallel_rank * tensor_parallel_size
dist.broadcast(B_temp, src=src_b, group=gpc.get_group(col_parallel_mode))
C_temp = torch.matmul(A, B_temp.transpose(0, 1))
src_c = i + row_rank * tesseract_dim + dep_rank * (
tesseract_dim ** 2) + data_parallel_rank * pipeline_parallel_size * tensor_parallel_size + \
pipeline_parallel_rank * tensor_parallel_size
dist.reduce(C_temp, dst=src_c, group=gpc.get_group(row_parallel_mode))
if i == col_rank:
C = C_temp.clone()
if i != tesseract_dim - 1:
B_list[1 - cur].copy_(B)
opb[1 - cur] = dist.broadcast(B_list[1 - cur],
src=src_b + tesseract_dim,
group=col_group,
async_op=True)
if opr[cur] is not None:
opr[cur].wait()
if i - 2 == col_rank:
C.copy_(C_list[cur])
if opb[cur] is not None:
opb[cur].wait()
torch.matmul(A, B_list[cur].transpose(0, 1), out=C_list[cur])
opr[cur] = dist.reduce(C_list[cur], dst=src_c, group=row_group, async_op=True)
cur = 1 - cur
src_b += tesseract_dim
src_c += 1
for op in opr:
op.wait()
if tesseract_dim - 2 == col_rank:
C.copy_(C_list[cur])
if tesseract_dim - 1 == col_rank:
C.copy_(C_list[1 - cur])
out = C.reshape(out_shape)
if ctx:
......@@ -272,20 +330,52 @@ class Matmul_ATB_2p5D(torch.autograd.Function):
C_shape = (A.shape[-1], B.shape[-1])
C = torch.empty(C_shape, dtype=A.dtype, device=get_current_device())
# use circular buffer to store the communication tensor
# 2 is enough for all cases
A_list = [torch.empty_like(A) for _ in range(2)]
C_list = [torch.empty_like(C) for _ in range(2)]
row_group = gpc.get_group(row_parallel_mode)
col_group = gpc.get_group(col_parallel_mode)
src_a = tesseract_dim * row_rank + tesseract_dim ** 2 * dep_rank + data_parallel_rank * pipeline_parallel_size * tensor_parallel_size + \
pipeline_parallel_rank * tensor_parallel_size
src_c = col_rank + tesseract_dim ** 2 * dep_rank + data_parallel_rank * pipeline_parallel_size * tensor_parallel_size + \
pipeline_parallel_rank * tensor_parallel_size
opa = [None] * 2
opr = [None] * 2
A_list[0].copy_(A)
opa[0] = dist.broadcast(A_list[0], src=src_a, group=row_group, async_op=True)
cur = 0
for i in range(tesseract_dim):
A_temp = A.clone()
src_a = i + row_rank * tesseract_dim + dep_rank * (
tesseract_dim ** 2) + data_parallel_rank * pipeline_parallel_size * tensor_parallel_size + \
pipeline_parallel_rank * tensor_parallel_size
dist.broadcast(A_temp, src=src_a, group=get_parallel_group(row_parallel_mode))
C_temp = torch.matmul(A_temp.transpose(0, 1), B)
src_c = col_rank + i * tesseract_dim + dep_rank * (
tesseract_dim ** 2) + data_parallel_rank * pipeline_parallel_size * tensor_parallel_size + \
pipeline_parallel_rank * tensor_parallel_size
dist.reduce(C_temp, dst=src_c, group=get_parallel_group(col_parallel_mode))
if i == row_rank:
C = C_temp.clone()
if i != tesseract_dim - 1:
A_list[1 - cur].copy_(A)
opa[1 - cur] = dist.broadcast(A_list[1 - cur], src=src_a + 1, group=row_group, async_op=True)
if opr[cur] is not None:
opr[cur].wait()
if i - 2 == row_rank:
C.copy_(C_list[cur])
if opa[cur] is not None:
opa[cur].wait()
torch.matmul(A_list[cur].transpose(0, 1), B, out=C_list[cur])
opr[cur] = dist.reduce(C_list[cur], dst=src_c, group=col_group, async_op=True)
cur = 1 - cur
src_a += 1
src_c += tesseract_dim
for op in opr:
op.wait()
if tesseract_dim - 2 == row_rank:
C.copy_(C_list[cur])
if tesseract_dim - 1 == row_rank:
C.copy_(C_list[1 - cur])
out = C.reshape(out_shape)
if ctx:
......@@ -333,8 +423,7 @@ class Add_Bias_2p5D(torch.autograd.Function):
bias_temp = bias.clone()
else:
bias_temp = torch.zeros(output_size_per_partition, dtype=bias.dtype, device=get_current_device())
src_rank = col_rank + dep_rank * (
tesseract_dim ** 2) + data_parallel_rank * pipeline_parallel_size * tensor_parallel_size + \
src_rank = col_rank + dep_rank * tesseract_dim ** 2 + data_parallel_rank * pipeline_parallel_size * tensor_parallel_size + \
pipeline_parallel_rank * tensor_parallel_size
dist.broadcast(bias_temp, src=src_rank, group=get_parallel_group(col_parallel_mode))
......@@ -469,7 +558,9 @@ class SplitFirst(torch.autograd.Function):
return grad, None, None
def split_batch_2p5d(input_: Tensor, dim: int = 0) -> Tensor:
def split_tensor_2p5d(input_: Tensor, dim: int = 0) -> Tensor:
if input_.size(dim) <= 1:
return input_
return torch.chunk(input_, gpc.get_world_size(ParallelMode.PARALLEL_2P5D_COL),
dim=dim)[gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_COL)].contiguous()
......@@ -477,17 +568,28 @@ def split_batch_2p5d(input_: Tensor, dim: int = 0) -> Tensor:
class reduce_by_batch_2p5d(torch.autograd.Function):
"""All-reduce the input from the model parallel region."""
@staticmethod
def symbolic(graph, input_):
dist.all_reduce(input_, group=gpc.get_group(ParallelMode.PARALLEL_2P5D_COL))
return input_
def symbolic(graph, input_, reduce_mean: bool = False):
output = all_reduce(input_, ParallelMode.PARALLEL_2P5D_COL)
if reduce_mean:
reduce_size = gpc.get_world_size(ParallelMode.PARALLEL_2P5D_COL)
return output / reduce_size
return output
@staticmethod
@custom_fwd(cast_inputs=torch.float32)
def forward(ctx, input_):
dist.all_reduce(input_, group=gpc.get_group(ParallelMode.PARALLEL_2P5D_COL))
return input_.clone()
def forward(ctx, input_, reduce_mean: bool = False):
output = all_reduce(input_, ParallelMode.PARALLEL_2P5D_COL)
ctx.reduce_mean = reduce_mean
if reduce_mean:
reduce_size = gpc.get_world_size(ParallelMode.PARALLEL_2P5D_COL)
ctx.reduce_size = reduce_size
return output.clone() / reduce_size
return output.clone()
@staticmethod
@custom_bwd
def backward(ctx, grad_output):
return grad_output
def backward(ctx, output_grad):
if ctx.reduce_mean:
return output_grad / ctx.reduce_size, None
else:
return output_grad, None
......@@ -13,10 +13,9 @@ from colossalai.utils import get_current_device
from torch import Tensor, dtype
from torch.nn import Parameter
from .._common_utils import (divide, set_tensor_parallel_attribute_by_partition, to_2tuple)
from ..base_layer import ParallelLayer
from ._operation import (Add_Bias_2p5D, Matmul_AB_2p5D, all_gather_weight_2p5d, classifier_2p5d, layernorm_2p5d,
split_batch_2p5d)
from ..utils import (divide, set_tensor_parallel_attribute_by_partition, to_2tuple)
from ._operation import (Add_Bias_2p5D, Matmul_AB_2p5D, all_gather_weight_2p5d, classifier_2p5d, layernorm_2p5d)
from ._utils import (assert_tesseract_initialization, get_tesseract_dim_dep_from_env)
......@@ -231,7 +230,7 @@ class PatchEmbedding2p5D(ParallelLayer):
self.num_patches = self.grid_size[0] * self.grid_size[1]
self.flatten = flatten
self.embed_size = embed_size
self.embed_size_per_partition = embed_size // (self.tesseract_dep * self.tesseract_dim**2)
self.embed_size_per_partition = embed_size // self.tesseract_dim**2
with seed(ParallelMode.TENSOR):
self.weight = Parameter(
......@@ -251,10 +250,10 @@ class PatchEmbedding2p5D(ParallelLayer):
self._set_tensor_parallel_attribute()
def _set_tensor_parallel_attribute(self):
set_tensor_parallel_attribute_by_partition(self.weight, self.tesseract_dep * self.tesseract_dim**2)
set_tensor_parallel_attribute_by_partition(self.bias, self.tesseract_dep * self.tesseract_dim**2)
set_tensor_parallel_attribute_by_partition(self.cls_token, self.tesseract_dep * self.tesseract_dim**2)
set_tensor_parallel_attribute_by_partition(self.pos_embed, self.tesseract_dep * self.tesseract_dim**2)
set_tensor_parallel_attribute_by_partition(self.weight, self.tesseract_dim**2)
set_tensor_parallel_attribute_by_partition(self.bias, self.tesseract_dim**2)
set_tensor_parallel_attribute_by_partition(self.cls_token, self.tesseract_dim**2)
set_tensor_parallel_attribute_by_partition(self.pos_embed, self.tesseract_dim**2)
def reset_parameters(self, weight_initializer, bias_initializer, position_embed_initializer):
with seed(ParallelMode.TENSOR):
......@@ -269,8 +268,6 @@ class PatchEmbedding2p5D(ParallelLayer):
assert H == self.img_size[0] and W == self.img_size[1], \
f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
input_ = split_batch_2p5d(input_)
weight = all_gather_weight_2p5d.apply(self.weight, 0, self.tesseract_dim, ParallelMode.PARALLEL_2P5D_COL)
bias = all_gather_weight_2p5d.apply(self.bias, 0, self.tesseract_dim, ParallelMode.PARALLEL_2P5D_COL)
......@@ -303,7 +300,7 @@ class Embedding2p5D(ParallelLayer):
self.tesseract_dim, self.tesseract_dep = get_tesseract_dim_dep_from_env()
self.num_embeddings = num_embeddings
self.embed_dim = embedding_dim
embed_dim_per_partition = embedding_dim // (self.tesseract_dep * self.tesseract_dim**2)
embed_dim_per_partition = embedding_dim // self.tesseract_dim**2
self.padding_idx = padding_idx
self.embed_args = args
......@@ -316,7 +313,7 @@ class Embedding2p5D(ParallelLayer):
self._set_tensor_parallel_attributes()
def _set_tensor_parallel_attributes(self):
set_tensor_parallel_attribute_by_partition(self.weight, self.tesseract_dep * self.tesseract_dim**2)
set_tensor_parallel_attribute_by_partition(self.weight, self.tesseract_dim**2)
def reset_parameters(self, weight_initializer) -> None:
with seed(ParallelMode.TENSOR):
......@@ -330,8 +327,6 @@ class Embedding2p5D(ParallelLayer):
self.weight[self.padding_idx].fill_(0)
def forward(self, input_: Tensor) -> Tensor:
input_ = split_batch_2p5d(input_)
weight = all_gather_weight_2p5d.apply(self.weight, -1, self.tesseract_dim, ParallelMode.PARALLEL_2P5D_COL)
output = F.embedding(input_, weight, self.padding_idx, *self.embed_args, **self.embed_kwargs)
......@@ -359,7 +354,7 @@ class Classifier2p5D(ParallelLayer):
self.tesseract_dim, self.tesseract_dep = get_tesseract_dim_dep_from_env()
# partitioning dimension
self.input_size_per_partition = divide(self.in_features, self.tesseract_dep * self.tesseract_dim**2)
self.input_size_per_partition = divide(self.in_features, self.tesseract_dim**2)
if weight is not None:
self.weight = weight
......@@ -378,7 +373,7 @@ class Classifier2p5D(ParallelLayer):
def _set_tensor_parallel_attributes(self):
if self.has_weight:
set_tensor_parallel_attribute_by_partition(self.weight, self.tesseract_dep * self.tesseract_dim**2)
set_tensor_parallel_attribute_by_partition(self.weight, self.tesseract_dim**2)
def reset_parameters(self, weight_initializer, bias_initializer) -> None:
with seed(ParallelMode.TENSOR):
......
from ._operation import reduce_by_batch_3d, split_batch_3d
from ._operation import reduce_by_batch_3d, split_tensor_3d
from .layers import Classifier3D, Embedding3D, LayerNorm3D, Linear3D, PatchEmbedding3D
__all__ = [
'reduce_by_batch_3d', 'split_batch_3d', 'Linear3D', 'LayerNorm3D', 'PatchEmbedding3D', 'Classifier3D', 'Embedding3D'
'reduce_by_batch_3d', 'split_tensor_3d', 'Linear3D', 'LayerNorm3D', 'PatchEmbedding3D', 'Classifier3D', 'Embedding3D'
]
......@@ -175,10 +175,12 @@ class layernorm_3d(torch.autograd.Function):
return input_grad, weight_grad, bias_grad, None, None, None, None, None
def split_batch_3d(input_: Tensor,
input_parallel_mode: ParallelMode,
weight_parallel_mode: ParallelMode,
dim: int = 0) -> Tensor:
def split_tensor_3d(input_: Tensor,
dim: int = 0,
input_parallel_mode: ParallelMode = ParallelMode.PARALLEL_3D_INPUT,
weight_parallel_mode: ParallelMode = ParallelMode.PARALLEL_3D_WEIGHT) -> Tensor:
if input_.size(dim) <= 1:
return input_
output = torch.chunk(input_, gpc.get_world_size(weight_parallel_mode),
dim=dim)[gpc.get_local_rank(weight_parallel_mode)].contiguous()
output = torch.chunk(output, gpc.get_world_size(input_parallel_mode),
......@@ -189,15 +191,27 @@ def split_batch_3d(input_: Tensor,
class reduce_by_batch_3d(torch.autograd.Function):
@staticmethod
@custom_fwd(cast_inputs=torch.float32)
def forward(ctx, input_: Tensor, input_parallel_mode: ParallelMode, weight_parallel_mode: ParallelMode) -> Tensor:
def forward(ctx,
input_: Tensor,
input_parallel_mode: ParallelMode,
weight_parallel_mode: ParallelMode,
reduce_mean: bool = False) -> Tensor:
output = all_reduce(input_, input_parallel_mode)
output = all_reduce(output, weight_parallel_mode)
ctx.reduce_mean = reduce_mean
if reduce_mean:
reduce_size = gpc.get_world_size(input_parallel_mode) * gpc.get_world_size(weight_parallel_mode)
ctx.reduce_size = reduce_size
return output.clone() / reduce_size
return output.clone()
@staticmethod
@custom_bwd
def backward(ctx, output_grad: Tensor) -> Tuple[Tensor, ...]:
return output_grad, None, None
if ctx.reduce_mean:
return output_grad / ctx.reduce_size, None, None, None
else:
return output_grad, None, None, None
class broadcast_weight_3d_from_diagonal(torch.autograd.Function):
......
......@@ -17,9 +17,9 @@ from colossalai.utils import get_current_device
from torch import Tensor, dtype
from torch.nn import Parameter
from .._common_utils import (divide, set_tensor_parallel_attribute_by_partition, to_2tuple)
from ..utils import divide, set_tensor_parallel_attribute_by_partition, to_2tuple
from ._operation import *
from ._utils import (get_depth_from_env, get_last_group, get_parallel_mode_from_env, swap_in_out_group)
from ._utils import get_depth_from_env, get_last_group, get_parallel_mode_from_env, swap_in_out_group
@LAYERS.register_module
......@@ -241,8 +241,6 @@ class PatchEmbedding3D(ParallelLayer):
self.pos_embed.register_hook(self._sync_grad_hook)
def forward(self, input_: Tensor) -> Tensor:
input_ = split_batch_3d(input_, self.input_parallel_mode, self.weight_parallel_mode)
weight = broadcast_weight_3d_from_diagonal.apply(self.weight, self.input_parallel_mode,
self.weight_parallel_mode, self.output_parallel_mode)
output = F.conv2d(input_, weight, self.bias, stride=self.patch_size)
......@@ -302,8 +300,6 @@ class Embedding3D(ParallelLayer):
self.weight[self.padding_idx].fill_(0)
def forward(self, input_: Tensor) -> Tensor:
input_ = split_batch_3d(input_, self.input_parallel_mode, self.weight_parallel_mode)
weight = broadcast_weight_3d_from_diagonal.apply(self.weight, self.input_parallel_mode,
self.weight_parallel_mode, self.output_parallel_mode)
output = F.embedding(input_, weight, self.padding_idx, *self.embed_args, **self.embed_kwargs)
......
from .common import (ACT2FN, CheckpointModule, _ntuple, divide, get_tensor_parallel_mode,
set_tensor_parallel_attribute_by_partition, set_tensor_parallel_attribute_by_size, to_2tuple)
__all__ = [
'CheckpointModule', 'divide', 'ACT2FN', 'set_tensor_parallel_attribute_by_size',
'set_tensor_parallel_attribute_by_partition', 'get_tensor_parallel_mode', '_ntuple', 'to_2tuple'
]
......@@ -2,11 +2,12 @@
# -*- encoding: utf-8 -*-
import collections.abc
import os
from itertools import repeat
import numpy as np
import torch
from colossalai.constants import IS_TENSOR_PARALLEL, NUM_PARTITIONS
from colossalai.constants import (IS_TENSOR_PARALLEL, NUM_PARTITIONS, TENSOR_PARALLEL_MODE)
from colossalai.utils import checkpoint
from torch import Tensor, nn
......@@ -59,6 +60,10 @@ def set_tensor_parallel_attribute_by_partition(param, num_partitions):
setattr(param, NUM_PARTITIONS, num_partitions)
def get_tensor_parallel_mode():
return os.environ[TENSOR_PARALLEL_MODE]
# From PyTorch internals
......
......@@ -9,7 +9,7 @@ from colossalai.utils import get_current_device
from torch import Tensor, dtype
from torch import nn as nn
from .._common_utils import to_2tuple
from ..utils import to_2tuple
def drop_path(x, drop_prob: float = 0., training: bool = False):
......
......@@ -2,6 +2,7 @@ from torch import nn
from torch.nn.modules.loss import *
from torch.nn.modules.loss import _Loss
from colossalai.nn.layer.utils import get_tensor_parallel_mode
from .loss_2d import CrossEntropyLoss2D
from .loss_2p5d import CrossEntropyLoss2p5D
from .loss_3d import CrossEntropyLoss3D
......@@ -14,9 +15,10 @@ _parallel_cross_entropy = {
class CrossEntropyLoss(_Loss):
def __init__(self, reduction: bool = True, tensor_parallel: str = None, *args, **kwargs):
def __init__(self, reduction: bool = True, *args, **kwargs):
super().__init__()
if tensor_parallel in [None, '1d']:
tensor_parallel = get_tensor_parallel_mode()
if tensor_parallel in ['None', '1d']:
reduction = 'mean' if reduction else 'none'
self.loss = nn.CrossEntropyLoss(reduction=reduction, *args, **kwargs)
else:
......
from colossalai.nn.layer.parallel_2d import reduce_by_batch_2d, split_batch_2d
from colossalai.nn.layer.parallel_2d import reduce_by_batch_2d
from colossalai.nn.layer.parallel_2d._utils import assert_summa_initialization
from colossalai.registry import LOSSES
from torch.nn.functional import cross_entropy
......@@ -20,11 +20,8 @@ class CrossEntropyLoss2D(_Loss):
self.loss_kwargs = kwargs
def forward(self, logits, targets):
batch_size = targets.size(0)
targets = split_batch_2d(targets)
loss = cross_entropy(logits, targets, reduction='sum', *self.loss_args, **self.loss_kwargs)
loss = cross_entropy(logits, targets, reduction='none', *self.loss_args, **self.loss_kwargs)
if self.reduction_mean:
loss = loss.sum()
loss = reduce_by_batch_2d.apply(loss)
loss /= batch_size
loss = loss.mean()
loss = reduce_by_batch_2d.apply(loss, True)
return loss
from colossalai.nn.layer.parallel_2p5d import reduce_by_batch_2p5d, split_batch_2p5d
from colossalai.nn.layer.parallel_2p5d import reduce_by_batch_2p5d
from colossalai.nn.layer.parallel_2p5d._utils import assert_tesseract_initialization
from colossalai.registry import LOSSES
from torch.nn.functional import cross_entropy
......@@ -19,11 +19,8 @@ class CrossEntropyLoss2p5D(_Loss):
self.loss_kwargs = kwargs
def forward(self, logits, targets):
batch_size = targets.size(0)
targets = split_batch_2p5d(targets)
loss = cross_entropy(logits, targets, reduction='sum', *self.loss_args, **self.loss_kwargs)
loss = cross_entropy(logits, targets, reduction='none', *self.loss_args, **self.loss_kwargs)
if self.reduction_mean:
loss = loss.sum()
loss = reduce_by_batch_2p5d.apply(loss)
loss /= batch_size
loss = loss.mean()
loss = reduce_by_batch_2p5d.apply(loss, True)
return loss
from colossalai.constants import INPUT_GROUP_3D, WEIGHT_GROUP_3D
from colossalai.nn.layer.parallel_3d import reduce_by_batch_3d, split_batch_3d
from colossalai.nn.layer.parallel_3d import reduce_by_batch_3d
from colossalai.nn.layer.parallel_3d._utils import get_parallel_mode_from_env
from colossalai.registry import LOSSES
from torch.nn.functional import cross_entropy
from torch.nn.modules.loss import _Loss
@LOSSES.register_module
class CrossEntropyLoss3D(_Loss):
"""Cross entropy loss for 3D parallelism
......@@ -28,11 +27,8 @@ class CrossEntropyLoss3D(_Loss):
self.loss_kwargs = kwargs
def forward(self, logits, targets):
batch_size = targets.size(0)
targets = split_batch_3d(targets, self.input_parallel_mode, self.weight_parallel_mode)
loss = cross_entropy(logits, targets, reduction='sum', *self.loss_args, **self.loss_kwargs)
loss = cross_entropy(logits, targets, reduction='none', *self.loss_args, **self.loss_kwargs)
if self.reduction_mean:
loss = loss.sum()
loss = reduce_by_batch_3d.apply(loss, self.input_parallel_mode, self.weight_parallel_mode)
loss /= batch_size
loss = loss.mean()
loss = reduce_by_batch_3d.apply(loss, self.input_parallel_mode, self.weight_parallel_mode, True)
return loss
......@@ -4,6 +4,7 @@ from ._utils import calc_acc
from .accuracy_2d import Accuracy2D
from .accuracy_2p5d import Accuracy2p5D
from .accuracy_3d import Accuracy3D
from colossalai.nn.layer.utils import get_tensor_parallel_mode
_parallel_accuracy = {
'2d': Accuracy2D,
......@@ -13,9 +14,10 @@ _parallel_accuracy = {
class Accuracy(nn.Module):
def __init__(self, tensor_parallel: str = None):
def __init__(self):
super().__init__()
if tensor_parallel in [None, '1d']:
tensor_parallel = get_tensor_parallel_mode()
if tensor_parallel in ['None', '1d']:
self.acc = calc_acc
else:
self.acc = _parallel_accuracy[tensor_parallel]()
......
import torch
from colossalai.nn.layer.parallel_2d import reduce_by_batch_2d, split_batch_2d
from colossalai.nn.layer.parallel_2d import reduce_by_batch_2d
from torch import nn
from ._utils import calc_acc
......@@ -11,7 +11,6 @@ class Accuracy2D(nn.Module):
def forward(self, logits, targets):
with torch.no_grad():
targets = split_batch_2d(targets)
correct = calc_acc(logits, targets)
correct = reduce_by_batch_2d.apply(correct)
return correct
import torch
from colossalai.nn.layer.parallel_2p5d import reduce_by_batch_2p5d, split_batch_2p5d
from colossalai.nn.layer.parallel_2p5d import reduce_by_batch_2p5d
from torch import nn
from ._utils import calc_acc
......@@ -11,7 +11,6 @@ class Accuracy2p5D(nn.Module):
def forward(self, logits, targets):
with torch.no_grad():
targets = split_batch_2p5d(targets)
correct = calc_acc(logits, targets)
correct = reduce_by_batch_2p5d.apply(correct)
return correct
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment