Unverified Commit 01a80cd8 authored by アマデウス, committed by GitHub

Hotfix/Colossalai layers (#92)



* optimized 1d layer apis; reorganized nn.layer modules; fixed tests

* fixed 2.5d runtime issue

* reworked split batch, now called in trainer.schedule.load_batch
Co-authored-by: BoxiangW <45734921+BoxiangW@users.noreply.github.com>
parent 0fedef4f
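Note on the reworked 1D API (a hedged summary, not part of the diff itself): `Linear1D` now dispatches to `Linear1D_Col` or `Linear1D_Row` depending on whether the previous layer left its output split across the tensor-parallel group, and `CrossEntropyLoss` / `Accuracy` pick up the parallel mode from the environment instead of a constructor argument. A rough usage sketch follows; the import paths and the assumption of an already-initialized 1D tensor-parallel context are guesses, only the layer behaviour comes from the diff below.

# Rough usage sketch of the reworked API (assumes colossalai has been launched
# with a 1D tensor-parallel configuration; the import paths are an assumption).
import torch
from colossalai.nn.layer.parallel_1d.layers import Linear1D
from colossalai.nn.loss import CrossEntropyLoss

# The first Linear1D sees a non-parallel input and dispatches to Linear1D_Col
# (which calls set_parallel_input(True)); the second then sees a parallel
# input and dispatches to Linear1D_Row, so no explicit Col/Row choice is needed.
mlp = torch.nn.Sequential(Linear1D(1024, 4096), torch.nn.GELU(), Linear1D(4096, 1024))

# The loss no longer takes a tensor_parallel argument; it reads the mode via
# get_tensor_parallel_mode() (an environment variable set at launch).
criterion = CrossEntropyLoss()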
@@ -3,10 +3,10 @@
 import math
 import numbers
+from contextlib import nullcontext
 from typing import Callable, Tuple

 import torch
-import torch.distributed as dist
 import torch.nn.functional as F
 from colossalai.communication import broadcast
 from colossalai.context import ParallelMode, seed
@@ -14,13 +14,122 @@ from colossalai.core import global_context as gpc
 from colossalai.nn import init as init
 from colossalai.registry import LAYERS
 from colossalai.utils import get_current_device
-from torch import Tensor
+from torch import Tensor, dtype
 from torch.nn.parameter import Parameter

-from .._common_utils import divide, set_tensor_parallel_attribute_by_partition
 from ..base_layer import ParallelLayer
+from ..utils import divide, set_tensor_parallel_attribute_by_partition
 from ._operation import FusedLayerNormAffineFunction1D
-from ._utils import (gather_forward_split_backward, reduce_grad, reduce_input, split_forward_gather_backward)
+from ._utils import (gather_forward_split_backward, get_parallel_input, reduce_grad, reduce_input, set_parallel_input,
+                     split_forward_gather_backward)
+@LAYERS.register_module
+class Linear1D(torch.nn.Module):
+    def __init__(self,
+                 in_features: int,
+                 out_features: int,
+                 bias: bool = True,
+                 dtype: torch.dtype = None,
+                 gather_output: bool = False,
+                 skip_bias_add: bool = False,
+                 weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)),
+                 bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1)):
+        super().__init__()
+        parallel_input = get_parallel_input()
+        if not parallel_input:
+            self.layer = Linear1D_Col(in_features,
+                                      out_features,
+                                      bias=bias,
+                                      dtype=dtype,
+                                      gather_output=gather_output,
+                                      skip_bias_add=skip_bias_add,
+                                      weight_initializer=weight_initializer,
+                                      bias_initializer=bias_initializer)
+        else:
+            self.layer = Linear1D_Row(in_features,
+                                      out_features,
+                                      bias=bias,
+                                      dtype=dtype,
+                                      parallel_input=parallel_input,
+                                      skip_bias_add=skip_bias_add,
+                                      weight_initializer=weight_initializer,
+                                      bias_initializer=bias_initializer)
+
+    @property
+    def weight(self):
+        return self.layer.weight
+
+    @property
+    def bias(self):
+        return self.layer.bias
+
+    def forward(self, input_: Tensor) -> Tensor:
+        return self.layer(input_)
+
+
+@LAYERS.register_module
+class Classifier1D(ParallelLayer):
+    """RowLinear with given weight"""
+    def __init__(self,
+                 in_features: int,
+                 num_classes: int,
+                 weight: Parameter = None,
+                 bias: bool = True,
+                 dtype: dtype = None,
+                 weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)),
+                 bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1)):
+        super().__init__()
+        self.in_features = in_features
+        self.num_classes = num_classes
+        self.parallel_input = get_parallel_input()
+
+        # Divide the weight matrix along the last dimension.
+        self.input_size_per_partition = divide(in_features, gpc.tensor_parallel_size)
+
+        # Parameters.
+        # Initialize weight.
+        factory_kwargs = {'device': get_current_device(), 'dtype': dtype}
+        if weight is not None:
+            self.weight = weight
+            self.has_weight = False
+        else:
+            self.weight = Parameter(torch.empty(self.num_classes, self.input_size_per_partition, **factory_kwargs))
+            self.has_weight = True
+        if bias:
+            self.bias = Parameter(torch.empty(self.num_classes, **factory_kwargs))
+        else:
+            self.bias = None
+        with seed(ParallelMode.TENSOR):
+            self.reset_parameters(weight_initializer, bias_initializer)
+        self._set_tensor_parallel_attributes()
+        set_parallel_input(False)
+
+    def reset_parameters(self, weight_initializer, bias_initializer) -> None:
+        fan_in, fan_out = self.in_features, self.num_classes
+        if self.has_weight:
+            weight_initializer(self.weight, fan_in=fan_in, fan_out=fan_out)
+        if self.bias is not None:
+            bias_initializer(self.bias, fan_in=fan_in)
+            broadcast(self.bias, gpc.get_ranks_in_group(ParallelMode.PARALLEL_1D)[0], ParallelMode.PARALLEL_1D)
+
+    def _set_tensor_parallel_attributes(self):
+        if self.has_weight:
+            num_partition = gpc.get_world_size(ParallelMode.TENSOR)
+            set_tensor_parallel_attribute_by_partition(self.weight, num_partition)
+
+    def forward(self, input_: Tensor) -> Tensor:
+        # Set up backprop all-reduce.
+        if self.parallel_input:
+            input_ = input_
+        else:
+            input_ = split_forward_gather_backward(input_, ParallelMode.PARALLEL_1D, dim=-1)
+
+        output_parallel = F.linear(input_, self.weight)
+        output = reduce_input(output_parallel, ParallelMode.PARALLEL_1D)
+        output = output + self.bias
+        return output
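Classifier1D above follows the usual row-parallel pattern: the weight is partitioned along in_features, each rank computes a partial F.linear on its feature slice, and reduce_input sums the partials before the replicated bias is added. A minimal single-process sketch of that arithmetic (plain torch, made-up shapes, no distributed calls):

# Single-process check of the row-parallel math used by Classifier1D:
# partition the weight along in_features, compute partial outputs per "rank",
# and sum them (the sum plays the role of reduce_input's all-reduce).
import torch

world_size, in_features, num_classes, batch = 4, 16, 10, 2
x = torch.randn(batch, in_features)
w = torch.randn(num_classes, in_features)
b = torch.randn(num_classes)

# split_forward_gather_backward(input_, dim=-1) corresponds to chunking the features
x_parts = x.chunk(world_size, dim=-1)
w_parts = w.chunk(world_size, dim=-1)

partials = [xp @ wp.t() for xp, wp in zip(x_parts, w_parts)]   # per-rank F.linear
out = sum(partials) + b                                        # reduce_input + bias

assert torch.allclose(out, x @ w.t() + b, atol=1e-5)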
 @LAYERS.register_module
@@ -77,6 +186,7 @@ class Linear1D_Col(ParallelLayer):
         with seed(ParallelMode.TENSOR):
             self.reset_parameters(weight_initializer, bias_initializer)
         self._set_tensor_parallel_attributes()
+        set_parallel_input(True)

     def reset_parameters(self, weight_initializer, bias_initializer) -> None:
         fan_in, fan_out = self.in_features, self.out_features
@@ -158,6 +268,7 @@ class Linear1D_Row(ParallelLayer):
         with seed(ParallelMode.TENSOR):
             self.reset_parameters(weight_initializer, bias_initializer)
         self._set_tensor_parallel_attributes()
+        set_parallel_input(False)

     def reset_parameters(self, weight_initializer, bias_initializer) -> None:
         fan_in, fan_out = self.in_features, self.out_features
@@ -208,3 +319,68 @@ class MixedFusedLayerNorm1D(torch.nn.Module):
     def forward(self, input):
         return FusedLayerNormAffineFunction1D.apply(input, self.weight, self.bias, self.normalized_shape, self.eps)
+@LAYERS.register_module
+class Embedding1D(ParallelLayer):
+    def __init__(self,
+                 num_embeddings: int,
+                 embedding_dim: int,
+                 padding_idx: int = None,
+                 dtype: dtype = None,
+                 weight_initializer: Callable = init.normal_(),
+                 *args,
+                 **kwargs):
+        super().__init__()
+
+        self.num_embeddings = num_embeddings
+        self.embed_dim = embedding_dim
+        embed_dim_per_partition = divide(embedding_dim, gpc.tensor_parallel_size)
+
+        self.padding_idx = padding_idx
+        self.embed_args = args
+        self.embed_kwargs = kwargs
+
+        self.weight = Parameter(
+            torch.empty((num_embeddings, embed_dim_per_partition), device=get_current_device(), dtype=dtype))
+
+        self.reset_parameters(weight_initializer)
+        self._set_tensor_parallel_attributes()
+        set_parallel_input(False)
+
+    def _set_tensor_parallel_attributes(self):
+        set_tensor_parallel_attribute_by_partition(self.weight, gpc.tensor_parallel_size)
+
+    def reset_parameters(self, weight_initializer) -> None:
+        with seed(ParallelMode.TENSOR):
+            fan_in, fan_out = self.num_embeddings, self.embed_dim
+            weight_initializer(self.weight, fan_in=fan_in, fan_out=fan_out)
+            self._fill_padding_idx_with_zero()
+
+    def _fill_padding_idx_with_zero(self) -> None:
+        if self.padding_idx is not None:
+            with torch.no_grad():
+                self.weight[self.padding_idx].fill_(0)
+
+    def forward(self, input_: Tensor) -> Tensor:
+        output_parallel = F.embedding(input_, self.weight, self.padding_idx, *self.embed_args, **self.embed_kwargs)
+        output = gather_forward_split_backward(output_parallel, ParallelMode.PARALLEL_1D, dim=-1)
+        return output
+
+
+@LAYERS.register_module
+class Dropout1D(ParallelLayer):
+    def __init__(self, p: float = 0.5, inplace: bool = False):
+        super().__init__()
+        self.parallel_input = get_parallel_input()
+        self.p = p
+        self.inplace = inplace
+
+    def forward(self, input_: Tensor) -> Tensor:
+        cm = nullcontext() if not self.parallel_input else seed(ParallelMode.TENSOR)
+        with cm:
+            output = F.dropout(input_, self.p, self.training, self.inplace)
+        return output
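The new 1D layers coordinate through get_parallel_input / set_parallel_input from ._utils, which is not shown in this diff: Linear1D_Col marks its output as split (True), while Linear1D_Row, Classifier1D and Embedding1D mark theirs as gathered (False), and Linear1D, Classifier1D and Dropout1D read the flag to decide how to treat their input. A plausible minimal implementation is just module-level state; the sketch below is a guess for illustration, not the actual ColossalAI code.

# Hypothetical minimal version of the ._utils flag (a guess, not the real file).
_PARALLEL_INPUT = False

def set_parallel_input(parallel_input: bool) -> None:
    # Called by a layer to record whether its output is split along the hidden
    # dimension (True after Linear1D_Col, False after Linear1D_Row and friends).
    global _PARALLEL_INPUT
    _PARALLEL_INPUT = parallel_input

def get_parallel_input() -> bool:
    # Read by the next layer (Linear1D, Classifier1D, Dropout1D) to decide
    # whether its input still needs to be split or is already partitioned.
    return _PARALLEL_INPUT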
-from ._operation import reduce_by_batch_2d, split_batch_2d
+from ._operation import reduce_by_batch_2d, split_tensor_2d
 from .layers import Classifier2D, Embedding2D, LayerNorm2D, Linear2D, PatchEmbedding2D

 __all__ = [
-    'split_batch_2d', 'reduce_by_batch_2d', 'Linear2D', 'LayerNorm2D', 'Classifier2D', 'PatchEmbedding2D', 'Embedding2D'
+    'split_tensor_2d', 'reduce_by_batch_2d', 'Linear2D', 'LayerNorm2D', 'Classifier2D', 'PatchEmbedding2D', 'Embedding2D'
 ]
@@ -2,7 +2,7 @@ from typing import Any, Optional, Tuple

 import torch
 import torch.distributed as dist
-from colossalai.communication.collective import (all_gather, all_reduce, reduce_scatter)
+from colossalai.communication.collective import (all_gather, all_reduce, reduce, reduce_scatter)
 from colossalai.context.parallel_mode import ParallelMode
 from colossalai.core import global_context as gpc
 from colossalai.utils import get_current_device
@@ -595,7 +595,9 @@ class SplitFirst(torch.autograd.Function):
         return grad, None, None


-def split_batch_2d(input_: Tensor, dim: int = 0) -> Tensor:
+def split_tensor_2d(input_: Tensor, dim: int = 0) -> Tensor:
+    if input_.size(dim) <= 1:
+        return input_
     return torch.chunk(input_, gpc.get_world_size(ParallelMode.PARALLEL_2D_COL),
                        dim=dim)[gpc.get_local_rank(ParallelMode.PARALLEL_2D_COL)].contiguous()
@@ -603,17 +605,28 @@ def split_batch_2d(input_: Tensor, dim: int = 0) -> Tensor:

 class reduce_by_batch_2d(torch.autograd.Function):
     """All-reduce the input from the model parallel region."""
     @staticmethod
-    def symbolic(graph, input_):
-        dist.all_reduce(input_, group=gpc.get_group(ParallelMode.PARALLEL_2D_COL))
-        return input_
+    def symbolic(graph, input_, reduce_mean: bool = False):
+        output = all_reduce(input_, ParallelMode.PARALLEL_2D_COL)
+        if reduce_mean:
+            reduce_size = gpc.get_world_size(ParallelMode.PARALLEL_2D_COL)
+            return output / reduce_size
+        return output

     @staticmethod
     @custom_fwd(cast_inputs=torch.float32)
-    def forward(ctx, input_):
-        dist.all_reduce(input_, group=gpc.get_group(ParallelMode.PARALLEL_2D_COL))
-        return input_.clone()
+    def forward(ctx, input_, reduce_mean: bool = False):
+        output = all_reduce(input_, ParallelMode.PARALLEL_2D_COL)
+        ctx.reduce_mean = reduce_mean
+        if reduce_mean:
+            reduce_size = gpc.get_world_size(ParallelMode.PARALLEL_2D_COL)
+            ctx.reduce_size = reduce_size
+            return output.clone() / reduce_size
+        return output.clone()

     @staticmethod
     @custom_bwd
-    def backward(ctx, grad_output):
-        return grad_output
+    def backward(ctx, output_grad):
+        if ctx.reduce_mean:
+            return output_grad / ctx.reduce_size, None
+        else:
+            return output_grad, None
@@ -13,9 +13,9 @@ from colossalai.utils import get_current_device
 from torch import Tensor, dtype
 from torch.nn import Parameter

-from .._common_utils import (divide, set_tensor_parallel_attribute_by_partition, to_2tuple)
+from ..utils import divide, set_tensor_parallel_attribute_by_partition, to_2tuple
 from ..base_layer import ParallelLayer
-from ._operation import (Matmul_AB_2D, add_bias_2d, all_gather_weight_2d, classifier_2d, layernorm_2d, split_batch_2d)
+from ._operation import Matmul_AB_2D, add_bias_2d, all_gather_weight_2d, classifier_2d, layernorm_2d
 from ._utils import assert_summa_initialization, get_summa_dim_from_env
@@ -257,8 +257,6 @@ class PatchEmbedding2D(ParallelLayer):
         assert H == self.img_size[0] and W == self.img_size[1], \
             f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."

-        input_ = split_batch_2d(input_)
-
         weight = all_gather_weight_2d.apply(self.weight, 0, self.summa_dim, ParallelMode.PARALLEL_2D_COL)
         bias = all_gather_weight_2d.apply(self.bias, 0, self.summa_dim, ParallelMode.PARALLEL_2D_COL)
@@ -318,8 +316,6 @@ class Embedding2D(ParallelLayer):
                 self.weight[self.padding_idx].fill_(0)

     def forward(self, input_: Tensor) -> Tensor:
-        input_ = split_batch_2d(input_)
-
         weight = all_gather_weight_2d.apply(self.weight, -1, self.summa_dim, ParallelMode.PARALLEL_2D_COL)
         output = F.embedding(input_, weight, self.padding_idx, *self.embed_args, **self.embed_kwargs)
-from ._operation import reduce_by_batch_2p5d, split_batch_2p5d
+from ._operation import reduce_by_batch_2p5d, split_tensor_2p5d
 from .layers import Classifier2p5D, Embedding2p5D, LayerNorm2p5D, Linear2p5D, PatchEmbedding2p5D

 __all__ = [
-    'split_batch_2p5d', 'reduce_by_batch_2p5d', 'Linear2p5D', 'LayerNorm2p5D', 'Classifier2p5D', 'PatchEmbedding2p5D',
+    'split_tensor_2p5d', 'reduce_by_batch_2p5d', 'Linear2p5D', 'LayerNorm2p5D', 'Classifier2p5D', 'PatchEmbedding2p5D',
     'Embedding2p5D'
 ]
@@ -22,7 +22,7 @@ def get_parallel_rank(parallel_mode: ParallelMode):
     return gpc.get_local_rank(parallel_mode)


-def split_batch_2p5d(input_: Tensor, dim: int = 0) -> Tensor:
+def split_tensor_2p5d(input_: Tensor, dim: int = 0) -> Tensor:
     return torch.chunk(input_, gpc.get_world_size(ParallelMode.PARALLEL_2P5D_COL),
                        dim=dim)[gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_COL)].contiguous()
@@ -120,30 +120,53 @@ class Matmul_AB_2p5D(torch.autograd.Function):
         ctx.save_for_backward(A, B)

         A_shape = A.shape
-        A = A.reshape((-1, A_shape[-1])).contiguous()
+        A = A.reshape((-1, A_shape[-1]))
         B_shape = B.shape
-        B = B.reshape((-1, B_shape[-1])).contiguous()
+        B = B.reshape((-1, B_shape[-1]))
         C_shape = (A.shape[0], B.shape[-1])
         C = torch.zeros(C_shape, dtype=A.dtype, device=get_current_device())

-        A_list = [torch.empty_like(A) for _ in range(gpc.get_world_size(row_parallel_mode) - 1)]
-        B_list = [torch.empty_like(B) for _ in range(gpc.get_world_size(col_parallel_mode) - 1)]
-        A_list.insert(gpc.get_local_rank(row_parallel_mode), A)
-        B_list.insert(gpc.get_local_rank(col_parallel_mode), B)
-        op_a = dist.all_gather(A_list, A, group=gpc.get_group(row_parallel_mode), async_op=True)
-        op_a.wait()
-        op_b = dist.all_gather(B_list, B, group=gpc.get_group(col_parallel_mode), async_op=True)
-        for op in [op_a, op_b]:
-            op.wait()
+        # use circular buffer to store the communication tensor
+        # 2 is enough for all cases
+        A_list = [torch.empty_like(A) for _ in range(2)]
+        B_list = [torch.empty_like(B) for _ in range(2)]
+
+        row_group = gpc.get_group(row_parallel_mode)
+        col_group = gpc.get_group(col_parallel_mode)
+
+        src_a = tesseract_dim * row_rank + tesseract_dim ** 2 * dep_rank + data_parallel_rank * pipeline_parallel_size * tensor_parallel_size + \
+            pipeline_parallel_rank * tensor_parallel_size
+        src_b = col_rank + tesseract_dim ** 2 * dep_rank + data_parallel_rank * pipeline_parallel_size * tensor_parallel_size + \
+            pipeline_parallel_rank * tensor_parallel_size
+
+        opa = [None] * 2
+        opb = [None] * 2
+
+        A_list[0].copy_(A)
+        B_list[0].copy_(B)
+        opa[0] = dist.broadcast(A_list[0], src=src_a, group=row_group, async_op=True)
+        opb[0] = dist.broadcast(B_list[0], src=src_b, group=col_group, async_op=True)
+        cur = 0

         for i in range(tesseract_dim):
-            src_a = i + tesseract_dim * row_rank
-            src_b = i + tesseract_dim * col_rank
-            src_a = src_a % tesseract_dim
-            src_b = src_b % tesseract_dim
-            A_temp = A_list[src_a]
-            B_temp = B_list[src_b]
-            torch.addmm(C, A_temp, B_temp, out=C)
+            if i != tesseract_dim - 1:
+                A_list[1 - cur].copy_(A)
+                opa[1 - cur] = dist.broadcast(A_list[1 - cur], src=src_a + 1, group=row_group, async_op=True)
+                B_list[1 - cur].copy_(B)
+                opb[1 - cur] = dist.broadcast(B_list[1 - cur],
+                                              src=src_b + tesseract_dim,
+                                              group=col_group,
+                                              async_op=True)
+
+            if opa[cur] is not None:
+                opa[cur].wait()
+            if opb[cur] is not None:
+                opb[cur].wait()
+
+            torch.addmm(C, A_list[cur], B_list[cur], out=C)
+            cur = 1 - cur
+            src_a += 1
+            src_b += tesseract_dim

         out = C.reshape(out_shape)

         if ctx:
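The rewritten 2.5-D matmuls above replace the blocking per-step broadcasts (and the large all-gather in Matmul_AB_2p5D) with a two-slot circular buffer: while the matmul for step i consumes buffer `cur`, the broadcast for step i + 1 is already in flight into buffer `1 - cur`. Below is a stripped-down, single-process sketch of just that overlap pattern, with a stand-in for dist.broadcast(..., async_op=True); the tensors and the fake communication call are made up for illustration.

# Double-buffered communication/compute overlap, in miniature.
import torch

class _FakeWork:
    def wait(self):  # mimics the work handle returned by async torch.distributed ops
        pass

def fake_async_bcast(dst_buf: torch.Tensor, step: int) -> _FakeWork:
    dst_buf.fill_(float(step))   # pretend the broadcast delivered this step's data
    return _FakeWork()

steps, n = 4, 8
buf = [torch.empty(n) for _ in range(2)]   # two slots are enough
work = [None, None]
acc = torch.zeros(n)

work[0] = fake_async_bcast(buf[0], 0)      # prefetch step 0
cur = 0
for i in range(steps):
    if i != steps - 1:                     # prefetch the next step into the idle slot
        work[1 - cur] = fake_async_bcast(buf[1 - cur], i + 1)
    work[cur].wait()                       # make sure this step's data has arrived
    acc += buf[cur]                        # the compute overlaps with the prefetch
    cur = 1 - cur

assert acc.allclose(torch.full((n,), sum(range(steps)), dtype=torch.float))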
@@ -201,20 +224,55 @@ class Matmul_ABT_2p5D(torch.autograd.Function):
         C_shape = (A.shape[0], B.shape[0])
         C = torch.empty(C_shape, dtype=A.dtype, device=get_current_device())

+        # use circular buffer to store the communication tensor
+        # 2 is enough for all cases
+        B_list = [torch.empty_like(B) for _ in range(2)]
+        C_list = [torch.empty_like(C) for _ in range(2)]
+
+        row_group = gpc.get_group(row_parallel_mode)
+        col_group = gpc.get_group(col_parallel_mode)
+
+        src_b = col_rank + tesseract_dim ** 2 * dep_rank + data_parallel_rank * pipeline_parallel_size * tensor_parallel_size + \
+            pipeline_parallel_rank * tensor_parallel_size
+        src_c = tesseract_dim * row_rank + tesseract_dim ** 2 * dep_rank + data_parallel_rank * pipeline_parallel_size * tensor_parallel_size + \
+            pipeline_parallel_rank * tensor_parallel_size
+
+        opb = [None] * 2
+        opr = [None] * 2
+
+        B_list[0].copy_(B)
+        opb[0] = dist.broadcast(B_list[0], src=src_b, group=col_group, async_op=True)
+        cur = 0

         for i in range(tesseract_dim):
-            B_temp = B.clone()
-            src_b = col_rank + i * tesseract_dim + dep_rank * (
-                tesseract_dim ** 2) + data_parallel_rank * pipeline_parallel_size * tensor_parallel_size + \
-                pipeline_parallel_rank * tensor_parallel_size
-            dist.broadcast(B_temp, src=src_b, group=gpc.get_group(col_parallel_mode))
-            C_temp = torch.matmul(A, B_temp.transpose(0, 1))
-            src_c = i + row_rank * tesseract_dim + dep_rank * (
-                tesseract_dim ** 2) + data_parallel_rank * pipeline_parallel_size * tensor_parallel_size + \
-                pipeline_parallel_rank * tensor_parallel_size
-            dist.reduce(C_temp, dst=src_c, group=gpc.get_group(row_parallel_mode))
-            if i == col_rank:
-                C = C_temp.clone()
+            if i != tesseract_dim - 1:
+                B_list[1 - cur].copy_(B)
+                opb[1 - cur] = dist.broadcast(B_list[1 - cur],
+                                              src=src_b + tesseract_dim,
+                                              group=col_group,
+                                              async_op=True)
+
+            if opr[cur] is not None:
+                opr[cur].wait()
+                if i - 2 == col_rank:
+                    C.copy_(C_list[cur])
+
+            if opb[cur] is not None:
+                opb[cur].wait()
+
+            torch.matmul(A, B_list[cur].transpose(0, 1), out=C_list[cur])
+            opr[cur] = dist.reduce(C_list[cur], dst=src_c, group=row_group, async_op=True)
+            cur = 1 - cur
+            src_b += tesseract_dim
+            src_c += 1
+
+        for op in opr:
+            op.wait()
+
+        if tesseract_dim - 2 == col_rank:
+            C.copy_(C_list[cur])
+        if tesseract_dim - 1 == col_rank:
+            C.copy_(C_list[1 - cur])

         out = C.reshape(out_shape)

         if ctx:
@@ -272,20 +330,52 @@ class Matmul_ATB_2p5D(torch.autograd.Function):
         C_shape = (A.shape[-1], B.shape[-1])
         C = torch.empty(C_shape, dtype=A.dtype, device=get_current_device())

+        # use circular buffer to store the communication tensor
+        # 2 is enough for all cases
+        A_list = [torch.empty_like(A) for _ in range(2)]
+        C_list = [torch.empty_like(C) for _ in range(2)]
+
+        row_group = gpc.get_group(row_parallel_mode)
+        col_group = gpc.get_group(col_parallel_mode)
+
+        src_a = tesseract_dim * row_rank + tesseract_dim ** 2 * dep_rank + data_parallel_rank * pipeline_parallel_size * tensor_parallel_size + \
+            pipeline_parallel_rank * tensor_parallel_size
+        src_c = col_rank + tesseract_dim ** 2 * dep_rank + data_parallel_rank * pipeline_parallel_size * tensor_parallel_size + \
+            pipeline_parallel_rank * tensor_parallel_size
+
+        opa = [None] * 2
+        opr = [None] * 2
+
+        A_list[0].copy_(A)
+        opa[0] = dist.broadcast(A_list[0], src=src_a, group=row_group, async_op=True)
+        cur = 0

         for i in range(tesseract_dim):
-            A_temp = A.clone()
-            src_a = i + row_rank * tesseract_dim + dep_rank * (
-                tesseract_dim ** 2) + data_parallel_rank * pipeline_parallel_size * tensor_parallel_size + \
-                pipeline_parallel_rank * tensor_parallel_size
-            dist.broadcast(A_temp, src=src_a, group=get_parallel_group(row_parallel_mode))
-            C_temp = torch.matmul(A_temp.transpose(0, 1), B)
-            src_c = col_rank + i * tesseract_dim + dep_rank * (
-                tesseract_dim ** 2) + data_parallel_rank * pipeline_parallel_size * tensor_parallel_size + \
-                pipeline_parallel_rank * tensor_parallel_size
-            dist.reduce(C_temp, dst=src_c, group=get_parallel_group(col_parallel_mode))
-            if i == row_rank:
-                C = C_temp.clone()
+            if i != tesseract_dim - 1:
+                A_list[1 - cur].copy_(A)
+                opa[1 - cur] = dist.broadcast(A_list[1 - cur], src=src_a + 1, group=row_group, async_op=True)
+
+            if opr[cur] is not None:
+                opr[cur].wait()
+                if i - 2 == row_rank:
+                    C.copy_(C_list[cur])
+
+            if opa[cur] is not None:
+                opa[cur].wait()
+
+            torch.matmul(A_list[cur].transpose(0, 1), B, out=C_list[cur])
+            opr[cur] = dist.reduce(C_list[cur], dst=src_c, group=col_group, async_op=True)
+            cur = 1 - cur
+            src_a += 1
+            src_c += tesseract_dim
+
+        for op in opr:
+            op.wait()
+
+        if tesseract_dim - 2 == row_rank:
+            C.copy_(C_list[cur])
+        if tesseract_dim - 1 == row_rank:
+            C.copy_(C_list[1 - cur])

         out = C.reshape(out_shape)

         if ctx:
@@ -333,8 +423,7 @@ class Add_Bias_2p5D(torch.autograd.Function):
             bias_temp = bias.clone()
         else:
             bias_temp = torch.zeros(output_size_per_partition, dtype=bias.dtype, device=get_current_device())
-        src_rank = col_rank + dep_rank * (
-            tesseract_dim ** 2) + data_parallel_rank * pipeline_parallel_size * tensor_parallel_size + \
+        src_rank = col_rank + dep_rank * tesseract_dim ** 2 + data_parallel_rank * pipeline_parallel_size * tensor_parallel_size + \
             pipeline_parallel_rank * tensor_parallel_size
         dist.broadcast(bias_temp, src=src_rank, group=get_parallel_group(col_parallel_mode))
@@ -469,7 +558,9 @@ class SplitFirst(torch.autograd.Function):
         return grad, None, None


-def split_batch_2p5d(input_: Tensor, dim: int = 0) -> Tensor:
+def split_tensor_2p5d(input_: Tensor, dim: int = 0) -> Tensor:
+    if input_.size(dim) <= 1:
+        return input_
     return torch.chunk(input_, gpc.get_world_size(ParallelMode.PARALLEL_2P5D_COL),
                        dim=dim)[gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_COL)].contiguous()
@@ -477,17 +568,28 @@ def split_batch_2p5d(input_: Tensor, dim: int = 0) -> Tensor:

 class reduce_by_batch_2p5d(torch.autograd.Function):
     """All-reduce the input from the model parallel region."""
     @staticmethod
-    def symbolic(graph, input_):
-        dist.all_reduce(input_, group=gpc.get_group(ParallelMode.PARALLEL_2P5D_COL))
-        return input_
+    def symbolic(graph, input_, reduce_mean: bool = False):
+        output = all_reduce(input_, ParallelMode.PARALLEL_2P5D_COL)
+        if reduce_mean:
+            reduce_size = gpc.get_world_size(ParallelMode.PARALLEL_2P5D_COL)
+            return output / reduce_size
+        return output

     @staticmethod
     @custom_fwd(cast_inputs=torch.float32)
-    def forward(ctx, input_):
-        dist.all_reduce(input_, group=gpc.get_group(ParallelMode.PARALLEL_2P5D_COL))
-        return input_.clone()
+    def forward(ctx, input_, reduce_mean: bool = False):
+        output = all_reduce(input_, ParallelMode.PARALLEL_2P5D_COL)
+        ctx.reduce_mean = reduce_mean
+        if reduce_mean:
+            reduce_size = gpc.get_world_size(ParallelMode.PARALLEL_2P5D_COL)
+            ctx.reduce_size = reduce_size
+            return output.clone() / reduce_size
+        return output.clone()

     @staticmethod
     @custom_bwd
-    def backward(ctx, grad_output):
-        return grad_output
+    def backward(ctx, output_grad):
+        if ctx.reduce_mean:
+            return output_grad / ctx.reduce_size, None
+        else:
+            return output_grad, None
@@ -13,10 +13,9 @@ from colossalai.utils import get_current_device
 from torch import Tensor, dtype
 from torch.nn import Parameter

-from .._common_utils import (divide, set_tensor_parallel_attribute_by_partition, to_2tuple)
 from ..base_layer import ParallelLayer
-from ._operation import (Add_Bias_2p5D, Matmul_AB_2p5D, all_gather_weight_2p5d, classifier_2p5d, layernorm_2p5d,
-                         split_batch_2p5d)
+from ..utils import (divide, set_tensor_parallel_attribute_by_partition, to_2tuple)
+from ._operation import (Add_Bias_2p5D, Matmul_AB_2p5D, all_gather_weight_2p5d, classifier_2p5d, layernorm_2p5d)
 from ._utils import (assert_tesseract_initialization, get_tesseract_dim_dep_from_env)
@@ -231,7 +230,7 @@ class PatchEmbedding2p5D(ParallelLayer):
         self.num_patches = self.grid_size[0] * self.grid_size[1]
         self.flatten = flatten
         self.embed_size = embed_size
-        self.embed_size_per_partition = embed_size // (self.tesseract_dep * self.tesseract_dim**2)
+        self.embed_size_per_partition = embed_size // self.tesseract_dim**2

         with seed(ParallelMode.TENSOR):
             self.weight = Parameter(
@@ -251,10 +250,10 @@ class PatchEmbedding2p5D(ParallelLayer):
         self._set_tensor_parallel_attribute()

     def _set_tensor_parallel_attribute(self):
-        set_tensor_parallel_attribute_by_partition(self.weight, self.tesseract_dep * self.tesseract_dim**2)
-        set_tensor_parallel_attribute_by_partition(self.bias, self.tesseract_dep * self.tesseract_dim**2)
-        set_tensor_parallel_attribute_by_partition(self.cls_token, self.tesseract_dep * self.tesseract_dim**2)
-        set_tensor_parallel_attribute_by_partition(self.pos_embed, self.tesseract_dep * self.tesseract_dim**2)
+        set_tensor_parallel_attribute_by_partition(self.weight, self.tesseract_dim**2)
+        set_tensor_parallel_attribute_by_partition(self.bias, self.tesseract_dim**2)
+        set_tensor_parallel_attribute_by_partition(self.cls_token, self.tesseract_dim**2)
+        set_tensor_parallel_attribute_by_partition(self.pos_embed, self.tesseract_dim**2)

     def reset_parameters(self, weight_initializer, bias_initializer, position_embed_initializer):
         with seed(ParallelMode.TENSOR):
@@ -269,8 +268,6 @@ class PatchEmbedding2p5D(ParallelLayer):
         assert H == self.img_size[0] and W == self.img_size[1], \
             f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."

-        input_ = split_batch_2p5d(input_)
-
         weight = all_gather_weight_2p5d.apply(self.weight, 0, self.tesseract_dim, ParallelMode.PARALLEL_2P5D_COL)
         bias = all_gather_weight_2p5d.apply(self.bias, 0, self.tesseract_dim, ParallelMode.PARALLEL_2P5D_COL)
@@ -303,7 +300,7 @@ class Embedding2p5D(ParallelLayer):
         self.tesseract_dim, self.tesseract_dep = get_tesseract_dim_dep_from_env()
         self.num_embeddings = num_embeddings
         self.embed_dim = embedding_dim
-        embed_dim_per_partition = embedding_dim // (self.tesseract_dep * self.tesseract_dim**2)
+        embed_dim_per_partition = embedding_dim // self.tesseract_dim**2

         self.padding_idx = padding_idx
         self.embed_args = args
@@ -316,7 +313,7 @@ class Embedding2p5D(ParallelLayer):
         self._set_tensor_parallel_attributes()

     def _set_tensor_parallel_attributes(self):
-        set_tensor_parallel_attribute_by_partition(self.weight, self.tesseract_dep * self.tesseract_dim**2)
+        set_tensor_parallel_attribute_by_partition(self.weight, self.tesseract_dim**2)

     def reset_parameters(self, weight_initializer) -> None:
         with seed(ParallelMode.TENSOR):
@@ -330,8 +327,6 @@ class Embedding2p5D(ParallelLayer):
                 self.weight[self.padding_idx].fill_(0)

     def forward(self, input_: Tensor) -> Tensor:
-        input_ = split_batch_2p5d(input_)
-
         weight = all_gather_weight_2p5d.apply(self.weight, -1, self.tesseract_dim, ParallelMode.PARALLEL_2P5D_COL)
         output = F.embedding(input_, weight, self.padding_idx, *self.embed_args, **self.embed_kwargs)
@@ -359,7 +354,7 @@ class Classifier2p5D(ParallelLayer):
         self.tesseract_dim, self.tesseract_dep = get_tesseract_dim_dep_from_env()

         # partitioning dimension
-        self.input_size_per_partition = divide(self.in_features, self.tesseract_dep * self.tesseract_dim**2)
+        self.input_size_per_partition = divide(self.in_features, self.tesseract_dim**2)

         if weight is not None:
             self.weight = weight
@@ -378,7 +373,7 @@ class Classifier2p5D(ParallelLayer):
     def _set_tensor_parallel_attributes(self):
         if self.has_weight:
-            set_tensor_parallel_attribute_by_partition(self.weight, self.tesseract_dep * self.tesseract_dim**2)
+            set_tensor_parallel_attribute_by_partition(self.weight, self.tesseract_dim**2)

     def reset_parameters(self, weight_initializer, bias_initializer) -> None:
         with seed(ParallelMode.TENSOR):
-from ._operation import reduce_by_batch_3d, split_batch_3d
+from ._operation import reduce_by_batch_3d, split_tensor_3d
 from .layers import Classifier3D, Embedding3D, LayerNorm3D, Linear3D, PatchEmbedding3D

 __all__ = [
-    'reduce_by_batch_3d', 'split_batch_3d', 'Linear3D', 'LayerNorm3D', 'PatchEmbedding3D', 'Classifier3D', 'Embedding3D'
+    'reduce_by_batch_3d', 'split_tensor_3d', 'Linear3D', 'LayerNorm3D', 'PatchEmbedding3D', 'Classifier3D', 'Embedding3D'
 ]
@@ -175,10 +175,12 @@ class layernorm_3d(torch.autograd.Function):
         return input_grad, weight_grad, bias_grad, None, None, None, None, None


-def split_batch_3d(input_: Tensor,
-                   input_parallel_mode: ParallelMode,
-                   weight_parallel_mode: ParallelMode,
-                   dim: int = 0) -> Tensor:
+def split_tensor_3d(input_: Tensor,
+                    dim: int = 0,
+                    input_parallel_mode: ParallelMode = ParallelMode.PARALLEL_3D_INPUT,
+                    weight_parallel_mode: ParallelMode = ParallelMode.PARALLEL_3D_WEIGHT) -> Tensor:
+    if input_.size(dim) <= 1:
+        return input_
     output = torch.chunk(input_, gpc.get_world_size(weight_parallel_mode),
                          dim=dim)[gpc.get_local_rank(weight_parallel_mode)].contiguous()
     output = torch.chunk(output, gpc.get_world_size(input_parallel_mode),
@@ -189,15 +191,27 @@ def split_batch_3d(input_: Tensor,

 class reduce_by_batch_3d(torch.autograd.Function):
     @staticmethod
     @custom_fwd(cast_inputs=torch.float32)
-    def forward(ctx, input_: Tensor, input_parallel_mode: ParallelMode, weight_parallel_mode: ParallelMode) -> Tensor:
+    def forward(ctx,
+                input_: Tensor,
+                input_parallel_mode: ParallelMode,
+                weight_parallel_mode: ParallelMode,
+                reduce_mean: bool = False) -> Tensor:
         output = all_reduce(input_, input_parallel_mode)
         output = all_reduce(output, weight_parallel_mode)
+        ctx.reduce_mean = reduce_mean
+        if reduce_mean:
+            reduce_size = gpc.get_world_size(input_parallel_mode) * gpc.get_world_size(weight_parallel_mode)
+            ctx.reduce_size = reduce_size
+            return output.clone() / reduce_size
         return output.clone()

     @staticmethod
     @custom_bwd
     def backward(ctx, output_grad: Tensor) -> Tuple[Tensor, ...]:
-        return output_grad, None, None
+        if ctx.reduce_mean:
+            return output_grad / ctx.reduce_size, None, None, None
+        else:
+            return output_grad, None, None, None


 class broadcast_weight_3d_from_diagonal(torch.autograd.Function):
@@ -17,9 +17,9 @@ from colossalai.utils import get_current_device
 from torch import Tensor, dtype
 from torch.nn import Parameter

-from .._common_utils import (divide, set_tensor_parallel_attribute_by_partition, to_2tuple)
+from ..utils import divide, set_tensor_parallel_attribute_by_partition, to_2tuple
 from ._operation import *
-from ._utils import (get_depth_from_env, get_last_group, get_parallel_mode_from_env, swap_in_out_group)
+from ._utils import get_depth_from_env, get_last_group, get_parallel_mode_from_env, swap_in_out_group


 @LAYERS.register_module
@@ -241,8 +241,6 @@ class PatchEmbedding3D(ParallelLayer):
         self.pos_embed.register_hook(self._sync_grad_hook)

     def forward(self, input_: Tensor) -> Tensor:
-        input_ = split_batch_3d(input_, self.input_parallel_mode, self.weight_parallel_mode)
-
         weight = broadcast_weight_3d_from_diagonal.apply(self.weight, self.input_parallel_mode,
                                                          self.weight_parallel_mode, self.output_parallel_mode)
         output = F.conv2d(input_, weight, self.bias, stride=self.patch_size)
@@ -302,8 +300,6 @@ class Embedding3D(ParallelLayer):
                 self.weight[self.padding_idx].fill_(0)

     def forward(self, input_: Tensor) -> Tensor:
-        input_ = split_batch_3d(input_, self.input_parallel_mode, self.weight_parallel_mode)
-
         weight = broadcast_weight_3d_from_diagonal.apply(self.weight, self.input_parallel_mode,
                                                          self.weight_parallel_mode, self.output_parallel_mode)
         output = F.embedding(input_, weight, self.padding_idx, *self.embed_args, **self.embed_kwargs)
+from .common import (ACT2FN, CheckpointModule, _ntuple, divide, get_tensor_parallel_mode,
+                     set_tensor_parallel_attribute_by_partition, set_tensor_parallel_attribute_by_size, to_2tuple)
+
+__all__ = [
+    'CheckpointModule', 'divide', 'ACT2FN', 'set_tensor_parallel_attribute_by_size',
+    'set_tensor_parallel_attribute_by_partition', 'get_tensor_parallel_mode', '_ntuple', 'to_2tuple'
+]
@@ -2,11 +2,12 @@
 # -*- encoding: utf-8 -*-

 import collections.abc
+import os
 from itertools import repeat

 import numpy as np
 import torch
-from colossalai.constants import IS_TENSOR_PARALLEL, NUM_PARTITIONS
+from colossalai.constants import (IS_TENSOR_PARALLEL, NUM_PARTITIONS, TENSOR_PARALLEL_MODE)
 from colossalai.utils import checkpoint
 from torch import Tensor, nn
@@ -59,6 +60,10 @@ def set_tensor_parallel_attribute_by_partition(param, num_partitions):
     setattr(param, NUM_PARTITIONS, num_partitions)


+def get_tensor_parallel_mode():
+    return os.environ[TENSOR_PARALLEL_MODE]
+
+
 # From PyTorch internals
@@ -9,7 +9,7 @@ from colossalai.utils import get_current_device
 from torch import Tensor, dtype
 from torch import nn as nn

-from .._common_utils import to_2tuple
+from ..utils import to_2tuple


 def drop_path(x, drop_prob: float = 0., training: bool = False):
@@ -2,6 +2,7 @@ from torch import nn
 from torch.nn.modules.loss import *
 from torch.nn.modules.loss import _Loss

+from colossalai.nn.layer.utils import get_tensor_parallel_mode
 from .loss_2d import CrossEntropyLoss2D
 from .loss_2p5d import CrossEntropyLoss2p5D
 from .loss_3d import CrossEntropyLoss3D
@@ -14,9 +15,10 @@ _parallel_cross_entropy = {

 class CrossEntropyLoss(_Loss):
-    def __init__(self, reduction: bool = True, tensor_parallel: str = None, *args, **kwargs):
+    def __init__(self, reduction: bool = True, *args, **kwargs):
         super().__init__()
-        if tensor_parallel in [None, '1d']:
+        tensor_parallel = get_tensor_parallel_mode()
+        if tensor_parallel in ['None', '1d']:
             reduction = 'mean' if reduction else 'none'
             self.loss = nn.CrossEntropyLoss(reduction=reduction, *args, **kwargs)
         else:
-from colossalai.nn.layer.parallel_2d import reduce_by_batch_2d, split_batch_2d
+from colossalai.nn.layer.parallel_2d import reduce_by_batch_2d
 from colossalai.nn.layer.parallel_2d._utils import assert_summa_initialization
 from colossalai.registry import LOSSES
 from torch.nn.functional import cross_entropy
@@ -20,11 +20,8 @@ class CrossEntropyLoss2D(_Loss):
         self.loss_kwargs = kwargs

     def forward(self, logits, targets):
-        batch_size = targets.size(0)
-        targets = split_batch_2d(targets)
-        loss = cross_entropy(logits, targets, reduction='sum', *self.loss_args, **self.loss_kwargs)
+        loss = cross_entropy(logits, targets, reduction='none', *self.loss_args, **self.loss_kwargs)
         if self.reduction_mean:
-            loss = loss.sum()
-            loss = reduce_by_batch_2d.apply(loss)
-            loss /= batch_size
+            loss = loss.mean()
+            loss = reduce_by_batch_2d.apply(loss, True)
         return loss
-from colossalai.nn.layer.parallel_2p5d import reduce_by_batch_2p5d, split_batch_2p5d
+from colossalai.nn.layer.parallel_2p5d import reduce_by_batch_2p5d
 from colossalai.nn.layer.parallel_2p5d._utils import assert_tesseract_initialization
 from colossalai.registry import LOSSES
 from torch.nn.functional import cross_entropy
@@ -19,11 +19,8 @@ class CrossEntropyLoss2p5D(_Loss):
         self.loss_kwargs = kwargs

     def forward(self, logits, targets):
-        batch_size = targets.size(0)
-        targets = split_batch_2p5d(targets)
-        loss = cross_entropy(logits, targets, reduction='sum', *self.loss_args, **self.loss_kwargs)
+        loss = cross_entropy(logits, targets, reduction='none', *self.loss_args, **self.loss_kwargs)
         if self.reduction_mean:
-            loss = loss.sum()
-            loss = reduce_by_batch_2p5d.apply(loss)
-            loss /= batch_size
+            loss = loss.mean()
+            loss = reduce_by_batch_2p5d.apply(loss, True)
         return loss
 from colossalai.constants import INPUT_GROUP_3D, WEIGHT_GROUP_3D
-from colossalai.nn.layer.parallel_3d import reduce_by_batch_3d, split_batch_3d
+from colossalai.nn.layer.parallel_3d import reduce_by_batch_3d
 from colossalai.nn.layer.parallel_3d._utils import get_parallel_mode_from_env
 from colossalai.registry import LOSSES
 from torch.nn.functional import cross_entropy
 from torch.nn.modules.loss import _Loss


 @LOSSES.register_module
 class CrossEntropyLoss3D(_Loss):
     """Cross entropy loss for 3D parallelism
@@ -28,11 +27,8 @@ class CrossEntropyLoss3D(_Loss):
         self.loss_kwargs = kwargs

     def forward(self, logits, targets):
-        batch_size = targets.size(0)
-        targets = split_batch_3d(targets, self.input_parallel_mode, self.weight_parallel_mode)
-        loss = cross_entropy(logits, targets, reduction='sum', *self.loss_args, **self.loss_kwargs)
+        loss = cross_entropy(logits, targets, reduction='none', *self.loss_args, **self.loss_kwargs)
         if self.reduction_mean:
-            loss = loss.sum()
-            loss = reduce_by_batch_3d.apply(loss, self.input_parallel_mode, self.weight_parallel_mode)
-            loss /= batch_size
+            loss = loss.mean()
+            loss = reduce_by_batch_3d.apply(loss, self.input_parallel_mode, self.weight_parallel_mode, True)
         return loss
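The loss changes above follow from the commit message: the batch is now split in trainer.schedule.load_batch, so each rank's logits and targets already cover only its shard. The losses therefore switch from sum-then-divide-by-global-batch-size to a local mean followed by reduce_by_batch_*(..., True), whose reduce_mean path averages across the batch group. For equal shards the two reductions are identical; a tiny single-process check of that identity:

# With the batch pre-split into equal shards, the mean of per-shard means
# equals the global mean, so reduce_by_batch_*(loss, True) reproduces the
# old sum-then-divide behaviour.
import torch

losses = torch.randn(16)                  # per-sample losses for the full batch
shards = losses.chunk(4)                  # what each of 4 "ranks" would see

global_mean = losses.mean()
shard_means = torch.stack([s.mean() for s in shards])
all_reduced_mean = shard_means.mean()     # the reduce_mean=True path

assert torch.allclose(global_mean, all_reduced_mean)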
@@ -4,6 +4,7 @@ from ._utils import calc_acc
 from .accuracy_2d import Accuracy2D
 from .accuracy_2p5d import Accuracy2p5D
 from .accuracy_3d import Accuracy3D
+from colossalai.nn.layer.utils import get_tensor_parallel_mode

 _parallel_accuracy = {
     '2d': Accuracy2D,
@@ -13,9 +14,10 @@ _parallel_accuracy = {

 class Accuracy(nn.Module):
-    def __init__(self, tensor_parallel: str = None):
+    def __init__(self):
         super().__init__()
-        if tensor_parallel in [None, '1d']:
+        tensor_parallel = get_tensor_parallel_mode()
+        if tensor_parallel in ['None', '1d']:
             self.acc = calc_acc
         else:
             self.acc = _parallel_accuracy[tensor_parallel]()
 import torch
-from colossalai.nn.layer.parallel_2d import reduce_by_batch_2d, split_batch_2d
+from colossalai.nn.layer.parallel_2d import reduce_by_batch_2d
 from torch import nn

 from ._utils import calc_acc
@@ -11,7 +11,6 @@ class Accuracy2D(nn.Module):
     def forward(self, logits, targets):
         with torch.no_grad():
-            targets = split_batch_2d(targets)
             correct = calc_acc(logits, targets)
             correct = reduce_by_batch_2d.apply(correct)
         return correct
 import torch
-from colossalai.nn.layer.parallel_2p5d import reduce_by_batch_2p5d, split_batch_2p5d
+from colossalai.nn.layer.parallel_2p5d import reduce_by_batch_2p5d
 from torch import nn

 from ._utils import calc_acc
@@ -11,7 +11,6 @@ class Accuracy2p5D(nn.Module):
     def forward(self, logits, targets):
         with torch.no_grad():
-            targets = split_batch_2p5d(targets)
             correct = calc_acc(logits, targets)
             correct = reduce_by_batch_2p5d.apply(correct)
         return correct