Commit e532679c authored by oahzxl's avatar oahzxl
Browse files

Merge branch 'main' of https://github.com/oahzxl/ColossalAI into chunk

parents c1492e50 7d5640b9
from colossalai.tensor import ColoTensor
from colossalai.context.singleton_meta import SingletonMeta
class GraphGlobalEnv(metaclass=SingletonMeta):
def __init__(self) -> None:
self.graph_building = False
self.graph_node_list = []
self.node_id = -1
def get_node_id(self):
self.node_id += 1
return self.node_id
def add_graph_node(self, node):
self.graph_node_list.append(node)
class GraphContext():
"""
Building the computing graph under the context
>>> with GraphContext():
>>> output = model(colo_input_tensor)
"""
graph_nodes = []
def __enter__(self):
GraphGlobalEnv().graph_building = True
GraphGlobalEnv().graph_node_list = []
def __exit__(self, *exc_info):
GraphGlobalEnv().graph_building = False
GraphGlobalEnv().node_id = -1
self.graph_nodes = GraphGlobalEnv().graph_node_list
class GraphNode(object):
def __init__(self) -> None:
self.prev_nodes = []
self.post_nodes = []
self.id = GraphGlobalEnv().get_node_id()
def add_prev_node(self, node):
if GraphGlobalEnv().graph_building:
self.prev_nodes.append(node)
def add_post_node(self, node):
if GraphGlobalEnv().graph_building:
self.post_nodes.append(node)
def post_node_empty(self) -> bool:
return len(self.post_nodes) == 0
class GraphOpNode(GraphNode):
def __init__(self, op_type, param_list) -> None:
super().__init__()
self._op_type = op_type
self._param_list = param_list
GraphGlobalEnv().add_graph_node(self)
def add_prev_tensor(self, colo_tensor: ColoTensor):
r"""
Link the current graph op node to previous graph op.
Op1 <- Activation (colo_tensor) Op2
Op1 <- Op2
"""
if GraphGlobalEnv().graph_building:
assert isinstance(colo_tensor, ColoTensor)
if colo_tensor._graph_node is None:
colo_tensor._graph_node = GraphNode()
prev_ops = colo_tensor._graph_node.prev_nodes
for op_node in prev_ops:
self.add_prev_node(op_node)
op_node.add_post_node(self)
def add_post_tensor(self, colo_tensor: ColoTensor):
"""
Op <- Activation (colo_tensor)
"""
if GraphGlobalEnv().graph_building:
assert isinstance(colo_tensor, ColoTensor), f'type {type(colo_tensor)}'
if colo_tensor._graph_node is None:
colo_tensor._graph_node = GraphNode()
colo_tensor._graph_node.add_prev_node(self)
def print(self):
print(
f'GraphOpNode {self._op_type} {self.id}, post nodes {[node.id for node in self.post_nodes]}, prev node number {[node.id for node in self.prev_nodes]}'
)
import functools
import torch
from colossalai.tensor import ColoTensor
from typing import Callable, List
from colossalai.nn._ops._utils import convert_to_colo_tensor
def register_colo_graph(input_pos: List[int], param_pos: List[int]) -> Callable:
"""register_colo_graph
Register a Op (Layer) to ColoGraph.
Recoders the input args in types of ColoTensor to the Graph.
Args:
func (Callable): a function implements the Op.
Returns:
Callable: wrapper function.
"""
def register_colo_graph_decorator(func):
from colossalai.nn.graph import GraphOpNode, GraphGlobalEnv
@functools.wraps(func)
def wrapper(*args, **kwargs):
param_list = []
input_list = []
# TODO(jiaruifang) find the pg
for idx, arg in enumerate(args):
if isinstance(arg, torch.Tensor) and idx in input_pos:
input_list.append(convert_to_colo_tensor(arg))
if isinstance(arg, torch.Tensor) and idx in param_pos:
param_list.append(convert_to_colo_tensor(arg))
# building the computing graph, inputs -> op
if GraphGlobalEnv().graph_building:
cur_op_node = GraphOpNode('linear', param_list)
# TODO supports a list of ColoTensor as args
if len(input_list) > 0:
cur_op_node.add_prev_tensor(input_list[0])
outputs = func(*args, **kwargs)
# building the computing graph, op -> output
if GraphGlobalEnv().graph_building:
# TODO supports a list of ColoTensor as args
if isinstance(outputs[0], ColoTensor):
cur_op_node.add_post_tensor(outputs[0])
return outputs
return wrapper
return register_colo_graph_decorator
import math
import inspect
from typing import Callable
from colossalai.utils import get_current_device
from torch import dtype, nn
from ... import init as init
from ..parallel_1d import *
from ..parallel_2d import *
from ..parallel_2p5d import *
from ..parallel_3d import *
from ..utils import get_tensor_parallel_mode
from ..vanilla import *
from ._utils import ColossalaiModule
_parallel_linear = {'1d': Linear1D, '2d': Linear2D, '2.5d': Linear2p5D, '3d': Linear3D}
_parallel_classifier = {
None: VanillaClassifier,
'1d': Classifier1D,
'2d': Classifier2D,
'2.5d': Classifier2p5D,
'3d': Classifier3D
}
_vocab_parallel_classifier = {
'1d': VocabParallelClassifier1D,
'2d': VocabParallelClassifier2D,
'2.5d': VocabParallelClassifier2p5D,
'3d': VocabParallelClassifier3D
}
class Linear(ColossalaiModule):
"""Linear layer of colossalai.
Args:
in_features (int): size of each input sample.
out_features (int): size of each output sample.
bias (bool, optional): If set to ``False``, the layer will not learn an additive bias, defaults to ``True``.
dtype (:class:`torch.dtype`, optional): The dtype of parameters, defaults to None.
weight_initializer (:class:`typing.Callable`, optional):
The initializer of weight, defaults to kaiming uniform initializer.
bias_initializer (:class:`typing.Callable`, optional):
The initializer of bias, defaults to xavier uniform initializer.
Note: ``kwargs`` would contain different parameters when you use different parallelisms.
The ``kwargs`` should contain parameters below:
::
Linear1D:
gather_output: bool (optional, default to be false)
skip_bias_add: bool (optional, default to be false)
Linear2D:
skip_bias_add: bool (optional, default to be false)
Linear2p5D:
skip_bias_add: bool (optional, default to be false)
Linear3D:
None
More details about ``initializer`` please refer to
`init <https://github.com/hpcaitech/ColossalAI/blob/main/colossalai/nn/init.py>`_.
"""
def __init__(self,
in_features: int,
out_features: int,
bias: bool = True,
dtype: dtype = None,
weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)),
bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1),
**kwargs) -> None:
tensor_parallel = get_tensor_parallel_mode()
if tensor_parallel is None:
layer = nn.Linear(in_features, out_features, bias=bias).to(dtype).to(get_current_device())
weight_initializer(layer.weight, fan_in=in_features, fan_out=out_features)
if layer.bias is not None:
bias_initializer(layer.bias, fan_in=in_features)
else:
linear_cls = _parallel_linear[tensor_parallel]
gather_output = kwargs.pop('gather_output', None)
if 'gather_output' in inspect.signature(
linear_cls.__init__).parameters.keys(): # gather_out arg is available
kwargs['gather_output'] = gather_output
layer = linear_cls(
in_features,
out_features,
bias=bias,
dtype=dtype,
weight_initializer=weight_initializer,
bias_initializer=bias_initializer,
**kwargs,
)
super().__init__(layer)
class Classifier(ColossalaiModule):
"""Classifier layer of colossalai.
Args:
in_features (int): size of each input sample.
num_classes (int): number of classes.
weight (:class:`torch.nn.Parameter`, optional): weight of the classifier, defaults to None.
bias (bool, optional): If set to ``False``, the layer will not learn an additive bias, defaults to ``True``.
dtype (:class:`torch.dtype`, optional): The dtype of parameters, defaults to None.
weight_initializer (:class:`typing.Callable`, optional):
The initializer of weight, defaults to kaiming uniform initializer.
bias_initializer (:class:`typing.Callable`, optional):
The initializer of bias, defaults to xavier uniform initializer.
More details about ``initializer`` please refer to
`init <https://github.com/hpcaitech/ColossalAI/blob/main/colossalai/nn/init.py>`_.
"""
def __init__(self,
in_features: int,
num_classes: int,
weight: nn.Parameter = None,
bias: bool = True,
dtype: dtype = None,
weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)),
bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1),
vocab_parallel_limit: int = 2048) -> None:
tensor_parallel = get_tensor_parallel_mode()
if num_classes <= vocab_parallel_limit or tensor_parallel is None:
layer = _parallel_classifier[tensor_parallel](
in_features,
num_classes,
weight=weight,
bias=bias,
dtype=dtype,
weight_initializer=weight_initializer,
bias_initializer=bias_initializer,
)
else:
layer = _vocab_parallel_classifier[tensor_parallel](
in_features,
num_classes,
weight=weight,
bias=bias,
dtype=dtype,
weight_initializer=weight_initializer,
bias_initializer=bias_initializer,
)
super().__init__(layer)
import inspect
import math
from typing import Callable
from torch import dtype, nn
from colossalai.utils import get_current_device
from ... import init as init
from ..parallel_1d import *
from ..parallel_2d import *
from ..parallel_2p5d import *
from ..parallel_3d import *
from ..utils import get_tensor_parallel_mode
from ..vanilla import *
from ._utils import ColossalaiModule
_parallel_linear = {None: VanillaLinear, '1d': Linear1D, '2d': Linear2D, '2.5d': Linear2p5D, '3d': Linear3D}
_parallel_classifier = {
None: VanillaClassifier,
'1d': Classifier1D,
'2d': Classifier2D,
'2.5d': Classifier2p5D,
'3d': Classifier3D
}
_vocab_parallel_classifier = {
'1d': VocabParallelClassifier1D,
'2d': VocabParallelClassifier2D,
'2.5d': VocabParallelClassifier2p5D,
'3d': VocabParallelClassifier3D
}
class Linear(ColossalaiModule):
"""Linear layer of colossalai.
Args:
in_features (int): size of each input sample.
out_features (int): size of each output sample.
bias (bool, optional): If set to ``False``, the layer will not learn an additive bias, defaults to ``True``.
dtype (:class:`torch.dtype`, optional): The dtype of parameters, defaults to None.
weight_initializer (:class:`typing.Callable`, optional):
The initializer of weight, defaults to kaiming uniform initializer.
bias_initializer (:class:`typing.Callable`, optional):
The initializer of bias, defaults to xavier uniform initializer.
Note: ``kwargs`` would contain different parameters when you use different parallelisms.
The ``kwargs`` should contain parameters below:
::
Linear1D:
gather_output: bool (optional, default to be false)
skip_bias_add: bool (optional, default to be false)
Linear2D:
skip_bias_add: bool (optional, default to be false)
Linear2p5D:
skip_bias_add: bool (optional, default to be false)
Linear3D:
None
More details about ``initializer`` please refer to
`init <https://github.com/hpcaitech/ColossalAI/blob/main/colossalai/nn/init.py>`_.
"""
def __init__(self,
in_features: int,
out_features: int,
bias: bool = True,
dtype: dtype = None,
weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)),
bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1),
**kwargs) -> None:
tensor_parallel = get_tensor_parallel_mode()
linear_cls = _parallel_linear[tensor_parallel]
gather_output = kwargs.pop('gather_output', None)
if 'gather_output' in inspect.signature(linear_cls.__init__).parameters.keys(): # gather_out arg is available
kwargs['gather_output'] = gather_output
layer = linear_cls(
in_features,
out_features,
bias=bias,
dtype=dtype,
weight_initializer=weight_initializer,
bias_initializer=bias_initializer,
**kwargs,
)
super().__init__(layer)
class Classifier(ColossalaiModule):
"""Classifier layer of colossalai.
Args:
in_features (int): size of each input sample.
num_classes (int): number of classes.
weight (:class:`torch.nn.Parameter`, optional): weight of the classifier, defaults to None.
bias (bool, optional): If set to ``False``, the layer will not learn an additive bias, defaults to ``True``.
dtype (:class:`torch.dtype`, optional): The dtype of parameters, defaults to None.
weight_initializer (:class:`typing.Callable`, optional):
The initializer of weight, defaults to kaiming uniform initializer.
bias_initializer (:class:`typing.Callable`, optional):
The initializer of bias, defaults to xavier uniform initializer.
More details about ``initializer`` please refer to
`init <https://github.com/hpcaitech/ColossalAI/blob/main/colossalai/nn/init.py>`_.
"""
def __init__(self,
in_features: int,
num_classes: int,
weight: nn.Parameter = None,
bias: bool = True,
dtype: dtype = None,
weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)),
bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1),
vocab_parallel_limit: int = 2048) -> None:
tensor_parallel = get_tensor_parallel_mode()
if num_classes <= vocab_parallel_limit or tensor_parallel is None:
layer = _parallel_classifier[tensor_parallel](
in_features,
num_classes,
weight=weight,
bias=bias,
dtype=dtype,
weight_initializer=weight_initializer,
bias_initializer=bias_initializer,
)
else:
layer = _vocab_parallel_classifier[tensor_parallel](
in_features,
num_classes,
weight=weight,
bias=bias,
dtype=dtype,
weight_initializer=weight_initializer,
bias_initializer=bias_initializer,
)
super().__init__(layer)
import torch
import torch.distributed as dist
from torch import Tensor
from typing import Any, Tuple, Optional
from torch.distributed import ProcessGroup
COL_MOE_KERNEL_FLAG = False
try:
import colossal_moe_cuda
COL_MOE_KERNEL_FLAG = True
except ImportError:
print("If you want to activate cuda mode for MoE, please install with cuda_ext!")
class AllGather(torch.autograd.Function):
@staticmethod
def forward(ctx: Any, inputs: Tensor, group: Optional[ProcessGroup] = None) -> Tensor:
if ctx is not None:
ctx.comm_grp = group
comm_size = dist.get_world_size(group)
if comm_size == 1:
return inputs.unsqueeze(0)
buffer_shape = (comm_size,) + inputs.shape
outputs = torch.empty(buffer_shape, dtype=inputs.dtype, device=inputs.device)
buffer_list = list(torch.chunk(outputs, comm_size, dim=0))
dist.all_gather(buffer_list, inputs, group=group)
return outputs
@staticmethod
def backward(ctx: Any, grad_outputs: Tensor) -> Tuple[Tensor, None]:
return ReduceScatter.forward(None, grad_outputs, ctx.comm_grp), None
class ReduceScatter(torch.autograd.Function):
@staticmethod
def forward(ctx: Any, inputs: Tensor, group: Optional[ProcessGroup] = None) -> Tensor:
if ctx is not None:
ctx.comm_grp = group
comm_size = dist.get_world_size(group)
if comm_size == 1:
return inputs.squeeze(0)
if not inputs.is_contiguous():
inputs = inputs.contiguous()
output_shape = inputs.shape[1:]
outputs = torch.empty(output_shape, dtype=inputs.dtype, device=inputs.device)
buffer_list = list(torch.chunk(inputs, comm_size, dim=0))
dist.reduce_scatter(outputs, buffer_list, group=group)
return outputs
@staticmethod
def backward(ctx: Any, grad_outputs: Tensor) -> Tuple[Tensor, None]:
return AllGather.forward(None, grad_outputs, ctx.comm_grp), None
class AllToAll(torch.autograd.Function):
"""Dispatches input tensor [e, c, h] to all experts by all_to_all_single
operation in torch.distributed.
"""
@staticmethod
def forward(ctx: Any, inputs: Tensor, group: Optional[ProcessGroup] = None) -> Tensor:
if ctx is not None:
ctx.comm_grp = group
if not inputs.is_contiguous():
inputs = inputs.contiguous()
if dist.get_world_size(group) == 1:
return inputs
output = torch.empty_like(inputs)
dist.all_to_all_single(output, inputs, group=group)
return output
@staticmethod
def backward(ctx: Any, *grad_outputs: Tensor) -> Tuple[Tensor, None]:
return AllToAll.forward(None, *grad_outputs, ctx.comm_grp), None
class MoeDispatch(torch.autograd.Function):
@staticmethod
def forward(ctx, tokens, mask, dest_idx, ec):
s = tokens.size(0)
h = tokens.size(1)
expert_input = colossal_moe_cuda.dispatch_forward(s, ec, h, tokens, mask, dest_idx)
ctx.save_for_backward(mask, dest_idx)
ctx.s = s
ctx.h = h
ctx.ec = ec
return expert_input
@staticmethod
def backward(ctx, output_grad):
mask, dest_idx = ctx.saved_tensors
d_tokens = colossal_moe_cuda.dispatch_backward(ctx.s, ctx.ec, ctx.h, output_grad, mask, dest_idx)
return d_tokens, None, None, None
class MoeCombine(torch.autograd.Function):
@staticmethod
def forward(ctx, expert_tokens, logits, mask, dest_idx, ec):
assert logits.dtype == torch.float32
s = logits.size(0)
e = logits.size(1)
c = ec // e
h = expert_tokens.size(-1)
fp16_flag = (expert_tokens.dtype == torch.float16)
cb_input = expert_tokens.to(torch.float32) if fp16_flag else expert_tokens
ctokens = colossal_moe_cuda.combine_forward(s, e, c, h, cb_input, logits, mask, dest_idx)
output = ctokens.to(torch.float16) if fp16_flag else ctokens
ctx.save_for_backward(expert_tokens, logits, mask, dest_idx)
ctx.s = s
ctx.e = e
ctx.c = c
ctx.h = h
ctx.fp16_flag = fp16_flag
return output
@staticmethod
def backward(ctx, tokens_grad):
expert_tokens, logits, mask, dest_idx = ctx.saved_tensors
cb_grad = tokens_grad.to(torch.float32) if tokens_grad.dtype is torch.float16 \
else tokens_grad
cb_input = expert_tokens.to(torch.float32) if ctx.fp16_flag else expert_tokens
d_expert, d_logits = colossal_moe_cuda.combine_backward(ctx.s, ctx.e, ctx.c, ctx.h, cb_grad, cb_input, logits,
mask, dest_idx)
d_expert = d_expert.to(torch.float16) if ctx.fp16_flag else d_expert
return d_expert, d_logits, None, None, None
def moe_cumsum(inputs: Tensor):
dim0 = inputs.size(0)
flag = (dim0 <= 1024) or (dim0 <= 2048 and dim0 % 2 == 0) or (dim0 % 4 == 0)
if flag and COL_MOE_KERNEL_FLAG:
return colossal_moe_cuda.cumsum_sub_one(inputs)
else:
return torch.cumsum(inputs, dim=0) - 1
from typing import Any, Optional, Tuple
import torch
import torch.distributed as dist
from torch import Tensor
from torch.distributed import ProcessGroup
COL_MOE_KERNEL_FLAG = False
try:
from colossalai._C import moe
except:
moe = None
def build_moe_if_not_prebuilt():
# load moe kernel during runtime if not pre-built
global moe
if moe is None:
from colossalai.kernel.op_builder import MOEBuilder
moe = MOEBuilder().load()
class AllGather(torch.autograd.Function):
@staticmethod
def forward(ctx: Any, inputs: Tensor, group: Optional[ProcessGroup] = None) -> Tensor:
global moe
if moe is None:
from colossalai.kernel.op_builder import MOEBuilder
moe = MOEBuilder().load()
if ctx is not None:
ctx.comm_grp = group
comm_size = dist.get_world_size(group)
if comm_size == 1:
return inputs.unsqueeze(0)
buffer_shape = (comm_size,) + inputs.shape
outputs = torch.empty(buffer_shape, dtype=inputs.dtype, device=inputs.device)
buffer_list = list(torch.chunk(outputs, comm_size, dim=0))
dist.all_gather(buffer_list, inputs, group=group)
return outputs
@staticmethod
def backward(ctx: Any, grad_outputs: Tensor) -> Tuple[Tensor, None]:
return ReduceScatter.forward(None, grad_outputs, ctx.comm_grp), None
class ReduceScatter(torch.autograd.Function):
@staticmethod
def forward(ctx: Any, inputs: Tensor, group: Optional[ProcessGroup] = None) -> Tensor:
if ctx is not None:
ctx.comm_grp = group
comm_size = dist.get_world_size(group)
if comm_size == 1:
return inputs.squeeze(0)
if not inputs.is_contiguous():
inputs = inputs.contiguous()
output_shape = inputs.shape[1:]
outputs = torch.empty(output_shape, dtype=inputs.dtype, device=inputs.device)
buffer_list = list(torch.chunk(inputs, comm_size, dim=0))
dist.reduce_scatter(outputs, buffer_list, group=group)
return outputs
@staticmethod
def backward(ctx: Any, grad_outputs: Tensor) -> Tuple[Tensor, None]:
return AllGather.forward(None, grad_outputs, ctx.comm_grp), None
class AllToAll(torch.autograd.Function):
"""Dispatches input tensor [e, c, h] to all experts by all_to_all_single
operation in torch.distributed.
"""
@staticmethod
def forward(ctx: Any, inputs: Tensor, group: Optional[ProcessGroup] = None) -> Tensor:
if ctx is not None:
ctx.comm_grp = group
if not inputs.is_contiguous():
inputs = inputs.contiguous()
if dist.get_world_size(group) == 1:
return inputs
output = torch.empty_like(inputs)
dist.all_to_all_single(output, inputs, group=group)
return output
@staticmethod
def backward(ctx: Any, *grad_outputs: Tensor) -> Tuple[Tensor, None]:
return AllToAll.forward(None, *grad_outputs, ctx.comm_grp), None
class MoeDispatch(torch.autograd.Function):
@staticmethod
def forward(ctx, tokens, mask, dest_idx, ec):
s = tokens.size(0)
h = tokens.size(1)
# load moe kernel during runtime if not pre-built
build_moe_if_not_prebuilt()
expert_input = moe.dispatch_forward(s, ec, h, tokens, mask, dest_idx)
ctx.save_for_backward(mask, dest_idx)
ctx.s = s
ctx.h = h
ctx.ec = ec
return expert_input
@staticmethod
def backward(ctx, output_grad):
mask, dest_idx = ctx.saved_tensors
d_tokens = moe.dispatch_backward(ctx.s, ctx.ec, ctx.h, output_grad, mask, dest_idx)
return d_tokens, None, None, None
class MoeCombine(torch.autograd.Function):
@staticmethod
def forward(ctx, expert_tokens, logits, mask, dest_idx, ec):
assert logits.dtype == torch.float32
s = logits.size(0)
e = logits.size(1)
c = ec // e
h = expert_tokens.size(-1)
# load moe kernel during runtime if not pre-built
build_moe_if_not_prebuilt()
fp16_flag = (expert_tokens.dtype == torch.float16)
cb_input = expert_tokens.to(torch.float32) if fp16_flag else expert_tokens
ctokens = moe.combine_forward(s, e, c, h, cb_input, logits, mask, dest_idx)
output = ctokens.to(torch.float16) if fp16_flag else ctokens
ctx.save_for_backward(expert_tokens, logits, mask, dest_idx)
ctx.s = s
ctx.e = e
ctx.c = c
ctx.h = h
ctx.fp16_flag = fp16_flag
return output
@staticmethod
def backward(ctx, tokens_grad):
expert_tokens, logits, mask, dest_idx = ctx.saved_tensors
cb_grad = tokens_grad.to(torch.float32) if tokens_grad.dtype is torch.float16 \
else tokens_grad
cb_input = expert_tokens.to(torch.float32) if ctx.fp16_flag else expert_tokens
d_expert, d_logits = moe.combine_backward(ctx.s, ctx.e, ctx.c, ctx.h, cb_grad, cb_input, logits, mask, dest_idx)
d_expert = d_expert.to(torch.float16) if ctx.fp16_flag else d_expert
return d_expert, d_logits, None, None, None
def moe_cumsum(inputs: Tensor):
dim0 = inputs.size(0)
flag = (dim0 <= 1024) or (dim0 <= 2048 and dim0 % 2 == 0) or (dim0 % 4 == 0)
if flag and COL_MOE_KERNEL_FLAG:
# load moe kernel during runtime if not pre-built
build_moe_if_not_prebuilt()
return moe.cumsum_sub_one(inputs)
else:
return torch.cumsum(inputs, dim=0) - 1
import torch
import torch.distributed as dist
from colossalai.core import global_context as gpc
try:
import fused_mix_prec_layer_norm_cuda
......@@ -43,3 +45,52 @@ class FusedLayerNormAffineFunction1D(torch.autograd.Function):
weight_, bias_, ctx.eps)
return grad_input, grad_weight, grad_bias, None, None
class LinearWithAsyncCommunication(torch.autograd.Function):
"""
Linear layer execution with asynchronous communication in backprop.
"""
@staticmethod
def forward(ctx, input_, weight, bias, parallel_mode, async_grad_allreduce):
ctx.save_for_backward(input_, weight)
ctx.use_bias = bias is not None
ctx.parallel_mode = parallel_mode
ctx.async_grad_allreduce = async_grad_allreduce
output = torch.matmul(input_, weight.t())
if bias is not None:
output = output + bias
return output
@staticmethod
def backward(ctx, grad_output):
input, weight = ctx.saved_tensors
use_bias = ctx.use_bias
total_input = input
grad_input = grad_output.matmul(weight)
# Convert the tensor shapes to 2D for execution compatibility
grad_output = grad_output.view(grad_output.shape[0] * grad_output.shape[1], grad_output.shape[2])
total_input = total_input.view(total_input.shape[0] * total_input.shape[1], total_input.shape[2])
if ctx.async_grad_allreduce:
# Asynchronous all-reduce
handle = dist.all_reduce(grad_input, group=gpc.get_group(ctx.parallel_mode), async_op=True)
# Delay the start of weight gradient computation shortly (3us) to have
# all-reduce scheduled first and have GPU resources allocated
_ = torch.empty(1, device=grad_output.device) + 1
grad_weight = grad_output.t().matmul(total_input)
grad_bias = grad_output.sum(dim=0) if use_bias else None
if ctx.async_grad_allreduce:
handle.wait()
return grad_input, grad_weight, grad_bias, None, None, None
def linear_with_async_comm(input_, weight, bias, parallel_mode, async_grad_allreduce):
return LinearWithAsyncCommunication.apply(input_, weight, bias, parallel_mode, async_grad_allreduce)
......@@ -7,6 +7,9 @@ from typing import Callable, Tuple
import torch
import torch.nn.functional as F
from torch import Tensor
from torch.nn.parameter import Parameter
from colossalai.communication import broadcast
from colossalai.context import ParallelMode, seed
from colossalai.core import global_context as gpc
......@@ -14,18 +17,33 @@ from colossalai.global_variables import tensor_parallel_env as env
from colossalai.kernel import LayerNorm
from colossalai.nn import init as init
from colossalai.registry import LAYERS
from colossalai.utils.checkpointing import (broadcast_state_dict, gather_tensor_parallel_state_dict,
partition_tensor_parallel_state_dict)
from colossalai.utils.checkpointing import (
broadcast_state_dict,
gather_tensor_parallel_state_dict,
partition_tensor_parallel_state_dict,
)
from colossalai.utils.cuda import get_current_device
from torch import Tensor
from torch.nn.parameter import Parameter
from ..vanilla import VanillaPatchEmbedding, VanillaLayerNorm
from ..base_layer import ParallelLayer
from ..colossalai_layer._utils import ColossalaiModule
from ..utils import divide, set_tensor_parallel_attribute_by_partition
from ._utils import (gather_forward_split_backward, get_parallel_input, reduce_grad, reduce_input, set_parallel_input,
split_forward_gather_backward)
from ..vanilla import VanillaLayerNorm, VanillaPatchEmbedding
from ._operation import linear_with_async_comm
from ._utils import (
gather_forward_split_backward,
get_parallel_input,
reduce_grad,
reduce_input,
set_parallel_input,
split_forward_gather_backward,
)
Fast_LN = None
try:
from apex.contrib.layer_norm.layer_norm import FastLayerNorm
Fast_LN = FastLayerNorm
except ImportError:
pass
@LAYERS.register_module
......@@ -59,12 +77,11 @@ class Linear1D(ColossalaiModule):
weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)),
bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1)):
parallel_input = get_parallel_input()
if not parallel_input:
if not parallel_input and not gather_output:
layer = Linear1D_Col(in_features,
out_features,
bias=bias,
dtype=dtype,
gather_output=gather_output,
skip_bias_add=skip_bias_add,
weight_initializer=weight_initializer,
bias_initializer=bias_initializer)
......@@ -96,8 +113,21 @@ class LayerNorm1D(ColossalaiModule):
dtype (:class:`torch.dtype`, optional): The dtype of parameters, defaults to None.
"""
_fast_ln_supported_sizes = [
1024, 1536, 2048, 2304, 3072, 3840, 4096, 5120, 6144, 8192, 10240, 12288, 12800, 15360, 16384, 18432, 20480,
24576, 25600, 30720, 32768, 40960, 49152, 65536
]
def __init__(self, normalized_shape: int, eps=1e-05, bias=True, dtype=None):
norm = VanillaLayerNorm(normalized_shape, eps=eps, bias=bias, dtype=dtype)
if Fast_LN is not None and normalized_shape in self._fast_ln_supported_sizes:
norm = Fast_LN(normalized_shape, eps=eps).to(dtype)
else:
norm = None
try:
from apex.normalization import FusedLayerNorm
norm = FusedLayerNorm(normalized_shape, eps=eps).to(dtype)
except ImportError:
norm = LayerNorm(normalized_shape, eps=eps).to(dtype)
super().__init__(norm)
def _load_from_state_dict(self, state_dict, prefix, *args):
......@@ -519,11 +549,12 @@ class Linear1D_Col(ParallelLayer):
'Invalid shapes in Linear1D_Col forward: input={}, weight={}. Expected last dim of input {}.'.format(
input_.shape, self.weight.shape, self.weight.shape[-1])
# Set up backprop all-reduce.
input_parallel = reduce_grad(input_, ParallelMode.PARALLEL_1D)
# input_parallel = reduce_grad(input_, ParallelMode.PARALLEL_1D)
input_parallel = input_
# Matrix multiply.
bias = self.bias if not self.skip_bias_add else None
output_parallel = F.linear(input_parallel, self.weight, bias)
# output_parallel = F.linear(input_parallel, self.weight, bias)
output_parallel = linear_with_async_comm(input_parallel, self.weight, bias, ParallelMode.PARALLEL_1D, True)
if self.gather_output:
# All-gather across the partitions.
output = gather_forward_split_backward(output_parallel, ParallelMode.PARALLEL_1D, dim=-1)
......@@ -565,9 +596,12 @@ class Linear1D_Row(ParallelLayer):
parallel_input: bool = True,
skip_bias_add: bool = False,
weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)),
bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1)):
bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1),
stream_chunk_num: int = 1):
super().__init__()
self.stream_chunk_num = stream_chunk_num
# Keep input parameters
self.in_features = in_features
self.out_features = out_features
......@@ -585,6 +619,9 @@ class Linear1D_Row(ParallelLayer):
factory_kwargs = {'device': get_current_device(), 'dtype': dtype}
self.weight = Parameter(torch.empty(self.out_features, self.input_size_per_partition, **factory_kwargs))
if self.stream_chunk_num > 1:
# TODO() work for inference only
self.chunk_weight()
if bias:
self.bias = Parameter(torch.empty(self.out_features, **factory_kwargs))
else:
......@@ -594,6 +631,9 @@ class Linear1D_Row(ParallelLayer):
self._set_tensor_parallel_attributes()
set_parallel_input(False)
def chunk_weight(self):
self.weight_list = torch.chunk(self.weight, self.stream_chunk_num, dim=0)
def reset_parameters(self, weight_initializer, bias_initializer) -> None:
fan_in, fan_out = self.in_features, self.out_features
weight_initializer(self.weight, fan_in=fan_in, fan_out=fan_out)
......@@ -664,9 +704,26 @@ class Linear1D_Row(ParallelLayer):
input_.shape, self.weight.shape, self.weight.shape[-1] * gpc.tensor_parallel_size)
input_ = split_forward_gather_backward(input_, ParallelMode.PARALLEL_1D, dim=-1)
output_parallel = F.linear(input_, self.weight)
output = reduce_input(output_parallel, ParallelMode.PARALLEL_1D)
if self.stream_chunk_num > 1:
if self.training:
raise RuntimeError("use stream_chunk_num=1 in Linear1D_Row for training!")
with torch.no_grad():
output_parallel_list = [None for i in range(self.stream_chunk_num)]
handle_list = []
for i in range(self.stream_chunk_num):
output_parallel_list[i] = F.linear(input_, self.weight_list[i])
handle = torch.distributed.all_reduce(output_parallel_list[i],
group=gpc.get_group(ParallelMode.PARALLEL_1D),
async_op=True)
handle_list.append(handle)
# output_parallel_list[i] = reduce_input(output_parallel_list[i], ParallelMode.PARALLEL_1D)
for handle in handle_list:
handle.wait()
output = torch.cat(output_parallel_list, dim=-1)
else:
output_parallel = F.linear(input_, self.weight)
# output_parallel = linear_with_async_comm(input_, self.weight, None, ParallelMode.PARALLEL_1D, False)
output = reduce_input(output_parallel, ParallelMode.PARALLEL_1D)
if not self.skip_bias_add:
if self.bias is not None:
output = output + self.bias
......
......@@ -4,91 +4,163 @@
from typing import Optional, Tuple
import torch
from colossalai.communication import (all_gather, all_reduce, broadcast, reduce, reduce_scatter)
from colossalai.context.parallel_mode import ParallelMode
from colossalai.core import global_context as gpc
from torch import Tensor
from torch.cuda.amp import custom_bwd, custom_fwd
from ._utils import get_parallel_mode_from_env
from colossalai.communication import all_gather, all_reduce, broadcast, reduce, reduce_scatter
from colossalai.constants import INPUT_GROUP_3D, WEIGHT_GROUP_3D
from colossalai.context.parallel_mode import ParallelMode
from colossalai.core import global_context as gpc
from ._utils import get_parallel_mode_from_env, push_async_grad
class _Linear3D(torch.autograd.Function):
@staticmethod
@custom_fwd(cast_inputs=torch.float16)
def forward(ctx,
input_: Tensor,
weight: Tensor,
bias: Optional[Tensor],
input_parallel_mode: ParallelMode,
weight_parallel_mode: ParallelMode,
output_parallel_mode: ParallelMode,
input_dim: int = 0,
weight_dim: int = -1,
output_dim: int = 0) -> Tensor:
ctx.use_bias = bias is not None
def forward(
ctx,
input_: Tensor,
weight: Tensor,
weight_id: int,
input_parallel_mode: ParallelMode,
weight_parallel_mode: ParallelMode,
output_parallel_mode: ParallelMode,
) -> Tensor:
ctx.weight_id = weight_id
ctx.input_parallel_mode = input_parallel_mode
ctx.weight_parallel_mode = weight_parallel_mode
ctx.output_parallel_mode = output_parallel_mode
input_ = all_gather(input_, input_dim, input_parallel_mode)
weight = all_gather(weight, weight_dim, weight_parallel_mode)
input_ = all_gather(input_, 0, input_parallel_mode)
weight = all_gather(weight, 0, weight_parallel_mode)
ctx.save_for_backward(input_, weight)
output = torch.matmul(input_, weight)
output = reduce_scatter(output, output_dim, output_parallel_mode)
output = reduce_scatter(output, 0, output_parallel_mode)
return output
@staticmethod
@custom_bwd
def backward(ctx, output_grad: Tensor) -> Tuple[Tensor, ...]:
input_, weight = ctx.saved_tensors
output_grad = all_gather(output_grad, 0, ctx.output_parallel_mode)
input_grad = torch.matmul(output_grad, weight.transpose(0, 1))
input_grad, input_op = reduce_scatter(input_grad, 0, ctx.input_parallel_mode, async_op=True)
weight_grad = torch.matmul(
input_.reshape(-1, input_.shape[-1]).transpose(0, 1), output_grad.reshape(-1, output_grad.shape[-1]))
weight_grad, op = reduce_scatter(weight_grad, 0, ctx.weight_parallel_mode, async_op=True)
weight_grad = push_async_grad(op, weight_grad, ctx.weight_id)
input_op.wait()
return input_grad, weight_grad, None, None, None, None
def linear_3d(
input_: Tensor,
weight: Tensor,
input_parallel_mode: ParallelMode,
weight_parallel_mode: ParallelMode,
output_parallel_mode: ParallelMode,
) -> Tensor:
r"""Linear layer for 3D parallelism.
Args:
input_ (:class:`torch.tensor`): input matrix.
weight (:class:`torch.tensor`): matrix of weight.
input_parallel_mode (:class:`colossalai.context.parallel_mode.ParallelMode`): input parallel mode.
weight_parallel_mode (:class:`colossalai.context.parallel_mode.ParallelMode`): weight parallel mode.
output_parallel_mode (:class:`colossalai.context.parallel_mode.ParallelMode`): output parallel mode.
Note:
The parallel_mode should be concluded in ``ParallelMode``. More details about ``ParallelMode`` could be found
in `parallel_mode <https://github.com/hpcaitech/ColossalAI/blob/main/colossalai/context/parallel_mode.py>`_
"""
return _Linear3D.apply(
input_,
weight,
id(weight),
input_parallel_mode,
weight_parallel_mode,
output_parallel_mode,
)
class _Classifier3D(torch.autograd.Function):
@staticmethod
@custom_fwd(cast_inputs=torch.float16)
def forward(
ctx,
input_: Tensor,
weight: Tensor,
bias: Optional[Tensor],
weight_id: int,
bias_id: Optional[int],
input_parallel_mode: ParallelMode,
weight_parallel_mode: ParallelMode,
output_parallel_mode: ParallelMode,
) -> Tensor:
ctx.use_bias = bias is not None
ctx.weight_id = weight_id
src_rank = gpc.get_ranks_in_group(input_parallel_mode)[gpc.get_local_rank(output_parallel_mode)]
weight = broadcast(weight, src_rank, input_parallel_mode)
ctx.save_for_backward(input_, weight)
output = torch.matmul(input_, weight.transpose(0, 1))
output = all_reduce(output, output_parallel_mode)
if bias is not None:
ctx.bias_id = bias_id
output += bias
ctx.src_rank = src_rank
ctx.input_parallel_mode = input_parallel_mode
ctx.weight_parallel_mode = weight_parallel_mode
ctx.output_parallel_mode = output_parallel_mode
ctx.input_dim = input_dim
ctx.weight_dim = weight_dim
ctx.output_dim = output_dim
return output
@staticmethod
@custom_bwd
def backward(ctx, output_grad: Tensor) -> Tuple[Tensor, ...]:
input_, weight = ctx.saved_tensors
with torch.no_grad():
output_grad = all_gather(output_grad, ctx.output_dim, ctx.output_parallel_mode)
async_ops = list()
input_grad = torch.matmul(output_grad, weight.transpose(0, 1))
input_grad, op = reduce_scatter(input_grad, ctx.input_dim, ctx.input_parallel_mode, async_op=True)
async_ops.append(op)
weight_grad = torch.matmul(
input_.reshape(-1, input_.shape[-1]).transpose(0, 1), output_grad.reshape(-1, output_grad.shape[-1]))
weight_grad, op = reduce_scatter(weight_grad, ctx.weight_dim, ctx.weight_parallel_mode, async_op=True)
async_ops.append(op)
if ctx.use_bias:
bias_grad = torch.sum(output_grad, dim=tuple(range(len(output_grad.shape))[:-1]))
bias_grad, op = all_reduce(bias_grad, ctx.weight_parallel_mode, async_op=True)
async_ops.append(op)
else:
bias_grad = None
for op in async_ops:
if op is not None:
op.wait()
return input_grad, weight_grad, bias_grad, None, None, None, None, None, None
def linear_3d(input_: Tensor,
weight: Tensor,
bias: Optional[Tensor],
input_parallel_mode: ParallelMode,
weight_parallel_mode: ParallelMode,
output_parallel_mode: ParallelMode,
input_dim: int = 0,
weight_dim: int = -1,
output_dim: int = 0) -> Tensor:
r"""Linear layer for 3D parallelism.
weight_grad = torch.matmul(
output_grad.reshape(-1, output_grad.shape[-1]).transpose(0, 1), input_.reshape(-1, input_.shape[-1]))
weight_grad = reduce(weight_grad, ctx.src_rank, ctx.input_parallel_mode)
if gpc.get_local_rank(ctx.input_parallel_mode) == gpc.get_local_rank(ctx.output_parallel_mode):
weight_grad, op = all_reduce(weight_grad, ctx.weight_parallel_mode, async_op=True)
weight_grad = push_async_grad(op, weight_grad, ctx.weight_id)
else:
weight_grad = None
if ctx.use_bias:
bias_grad = torch.sum(output_grad, dim=tuple(range(len(output_grad.shape))[:-1]))
bias_grad = all_reduce(bias_grad, ctx.input_parallel_mode)
bias_grad, op = all_reduce(bias_grad, ctx.weight_parallel_mode, async_op=True)
bias_grad = push_async_grad(op, bias_grad, ctx.bias_id)
else:
bias_grad = None
input_grad = torch.matmul(output_grad, weight)
return input_grad, weight_grad, bias_grad, None, None, None, None, None
def classifier_3d(
input_: Tensor,
weight: Tensor,
bias: Optional[Tensor],
input_parallel_mode: ParallelMode,
weight_parallel_mode: ParallelMode,
output_parallel_mode: ParallelMode,
) -> Tensor:
r"""3D parallel classifier.
Args:
input_ (:class:`torch.tensor`): input matrix.
......@@ -97,38 +169,52 @@ def linear_3d(input_: Tensor,
input_parallel_mode (:class:`colossalai.context.parallel_mode.ParallelMode`): input parallel mode.
weight_parallel_mode (:class:`colossalai.context.parallel_mode.ParallelMode`): weight parallel mode.
output_parallel_mode (:class:`colossalai.context.parallel_mode.ParallelMode`): output parallel mode.
input_dim (int, optional): dimension of input, defaults to 0.
weight_dim (int, optional): dimension of weight, defaults to -1.
output_dim (int, optional): dimension of output, defaults to 0.
Note:
The parallel_mode should be concluded in ``ParallelMode``. More details about ``ParallelMode`` could be found
in `parallel_mode <https://github.com/hpcaitech/ColossalAI/blob/main/colossalai/context/parallel_mode.py>`_
"""
return _Linear3D.apply(input_, weight, bias, input_parallel_mode, weight_parallel_mode, output_parallel_mode,
input_dim, weight_dim, output_dim)
return _Classifier3D.apply(
input_,
weight,
bias,
id(weight),
id(bias) if bias is not None else None,
input_parallel_mode,
weight_parallel_mode,
output_parallel_mode,
)
class _Classifier3D(torch.autograd.Function):
class _VocabParallelClassifier3D(torch.autograd.Function):
@staticmethod
@custom_fwd(cast_inputs=torch.float16)
def forward(ctx, input_: Tensor, weight: Tensor, bias: Optional[Tensor], input_parallel_mode: ParallelMode,
weight_parallel_mode: ParallelMode, output_parallel_mode: ParallelMode) -> Tensor:
def forward(
ctx,
input_: Tensor,
weight: Tensor,
bias: Optional[Tensor],
weight_id: int,
bias_id: Optional[int],
input_parallel_mode: ParallelMode,
weight_parallel_mode: ParallelMode,
output_parallel_mode: ParallelMode,
) -> Tensor:
ctx.use_bias = bias is not None
ctx.weight_id = weight_id
ranks_in_group = gpc.get_ranks_in_group(input_parallel_mode)
src_rank = ranks_in_group[gpc.get_local_rank(output_parallel_mode)]
weight = broadcast(weight, src_rank, input_parallel_mode)
input_ = all_gather(input_, 0, input_parallel_mode)
weight = all_gather(weight, 0, weight_parallel_mode).transpose(0, 1)
ctx.save_for_backward(input_, weight)
output = torch.matmul(input_, weight.transpose(0, 1))
output = all_reduce(output, output_parallel_mode)
output = torch.matmul(input_, weight)
output = reduce_scatter(output, 0, output_parallel_mode)
if bias is not None:
ctx.bias_id = bias_id
output += bias
ctx.src_rank = src_rank
ctx.input_parallel_mode = input_parallel_mode
ctx.weight_parallel_mode = weight_parallel_mode
ctx.output_parallel_mode = output_parallel_mode
......@@ -138,38 +224,37 @@ class _Classifier3D(torch.autograd.Function):
@custom_bwd
def backward(ctx, output_grad: Tensor) -> Tuple[Tensor, ...]:
input_, weight = ctx.saved_tensors
with torch.no_grad():
async_ops = list()
output_grad = all_gather(output_grad, 0, ctx.output_parallel_mode)
weight_grad = torch.matmul(
output_grad.reshape(-1, output_grad.shape[-1]).transpose(0, 1), input_.reshape(-1, input_.shape[-1]))
weight_grad = reduce(weight_grad, ctx.src_rank, ctx.input_parallel_mode)
if gpc.get_local_rank(ctx.input_parallel_mode) == gpc.get_local_rank(ctx.output_parallel_mode):
weight_grad, op = all_reduce(weight_grad, ctx.weight_parallel_mode, async_op=True)
async_ops.append(op)
else:
weight_grad = None
input_grad = torch.matmul(output_grad, weight.transpose(0, 1))
input_grad, input_op = reduce_scatter(input_grad, 0, ctx.input_parallel_mode, async_op=True)
if ctx.use_bias:
bias_grad = torch.sum(output_grad, dim=tuple(range(len(output_grad.shape))[:-1]))
bias_grad = all_reduce(bias_grad, ctx.input_parallel_mode)
bias_grad, op = all_reduce(bias_grad, ctx.weight_parallel_mode, async_op=True)
async_ops.append(op)
else:
bias_grad = None
weight_grad = torch.matmul(
input_.reshape(-1, input_.shape[-1]).transpose(0, 1), output_grad.reshape(-1, output_grad.shape[-1]))
weight_grad, op = reduce_scatter(weight_grad.transpose(0, 1), 0, ctx.weight_parallel_mode, async_op=True)
weight_grad = push_async_grad(op, weight_grad, ctx.weight_id)
input_grad = torch.matmul(output_grad, weight)
if ctx.use_bias:
bias_grad = torch.sum(output_grad, dim=tuple(range(len(output_grad.shape))[:-1]))
bias_grad, op = all_reduce(bias_grad, ctx.weight_parallel_mode, async_op=True)
bias_grad = push_async_grad(op, bias_grad, ctx.bias_id)
else:
bias_grad = None
for op in async_ops:
if op is not None:
op.wait()
input_op.wait()
return input_grad, weight_grad, bias_grad, None, None, None, None, None, None
return input_grad, weight_grad, bias_grad, None, None, None, None, None
def classifier_3d(input_: Tensor, weight: Tensor, bias: Optional[Tensor], input_parallel_mode: ParallelMode,
weight_parallel_mode: ParallelMode, output_parallel_mode: ParallelMode) -> Tensor:
r"""3D parallel classifier.
def vocab_parallel_classifier_3d(
input_: Tensor,
weight: Tensor,
bias: Optional[Tensor],
input_parallel_mode: ParallelMode,
weight_parallel_mode: ParallelMode,
output_parallel_mode: ParallelMode,
) -> Tensor:
r"""3D vocab parallel classifier.
Args:
input_ (:class:`torch.tensor`): input matrix.
......@@ -183,33 +268,72 @@ def classifier_3d(input_: Tensor, weight: Tensor, bias: Optional[Tensor], input_
The parallel_mode should be concluded in ``ParallelMode``. More details about ``ParallelMode`` could be found
in `parallel_mode <https://github.com/hpcaitech/ColossalAI/blob/main/colossalai/context/parallel_mode.py>`_
"""
return _Classifier3D.apply(input_, weight, bias, input_parallel_mode, weight_parallel_mode, output_parallel_mode)
return _VocabParallelClassifier3D.apply(
input_,
weight,
bias,
id(weight),
id(bias) if bias is not None else None,
input_parallel_mode,
weight_parallel_mode,
output_parallel_mode,
)
@torch.jit.script
def norm_forward(x: Tensor, mean: Tensor, sqr_mean: Tensor, weight: Tensor, bias: Tensor, eps: float):
mu = x - mean
var = sqr_mean - mean**2
sigma = torch.sqrt(var + eps)
z = mu / sigma
output = weight * z + bias
return output, mu, sigma
@torch.jit.script
def norm_backward(grad: Tensor, mu: Tensor, sigma: Tensor, weight: Tensor):
# dbias, dweight = grad, grad * mu / sigma
dz = grad * weight
dmu = dz / sigma
dvar = dz * mu * (-0.5) * sigma**(-3)
dmean = -dmu
dvar = torch.sum(dvar, -1, keepdim=True)
dmean = torch.sum(dmean, -1, keepdim=True)
return dmu, dmean, dvar
class _Layernorm3D(torch.autograd.Function):
@staticmethod
@custom_fwd(cast_inputs=torch.float32)
def forward(ctx, input_: Tensor, weight: Tensor, bias: Optional[Tensor], normalized_shape: int, eps: float,
input_parallel_mode: ParallelMode, weight_parallel_mode: ParallelMode,
output_parallel_mode: ParallelMode) -> Tensor:
mean = all_reduce(torch.sum(input_, dim=-1, keepdim=True), output_parallel_mode) / normalized_shape
mu = input_ - mean
var = all_reduce(torch.sum(mu**2, dim=-1, keepdim=True), output_parallel_mode) / normalized_shape
sigma = torch.sqrt(var + eps)
def forward(
ctx,
input_: Tensor,
weight: Tensor,
bias: Tensor,
weight_id: int,
bias_id: int,
normalized_shape: int,
eps: float,
output_parallel_mode: ParallelMode,
input_x_weight_parallel_mode: ParallelMode,
) -> Tensor:
ctx.weight_id = weight_id
ctx.bias_id = bias_id
sum_ = torch.sum(input_, dim=-1, keepdim=True)
sqr_sum = torch.sum(input_**2, dim=-1, keepdim=True)
mean, sqr_mean = all_reduce(torch.stack((sum_, sqr_sum)), output_parallel_mode) / normalized_shape
output, mu, sigma = norm_forward(input_, mean, sqr_mean, weight, bias, eps)
ctx.save_for_backward(mu, sigma, weight)
z = mu / sigma
output = weight * z
if bias is not None:
output = output + bias
ctx.use_bias = bias is not None
ctx.normalized_shape = normalized_shape
ctx.input_parallel_mode = input_parallel_mode
ctx.weight_parallel_mode = weight_parallel_mode
ctx.output_parallel_mode = output_parallel_mode
ctx.input_x_weight_parallel_mode = input_x_weight_parallel_mode
return output
......@@ -217,34 +341,31 @@ class _Layernorm3D(torch.autograd.Function):
@custom_bwd
def backward(ctx, output_grad: Tensor) -> Tuple[Tensor, ...]:
mu, sigma, weight = ctx.saved_tensors
with torch.no_grad():
weight_grad = output_grad * mu / sigma
if ctx.use_bias:
bias_grad = output_grad
weight_grad = torch.stack([bias_grad, weight_grad]).contiguous()
else:
bias_grad = None
weight_grad = torch.sum(weight_grad, dim=tuple(range(len(weight_grad.shape))[1:-1]))
weight_grad = all_reduce(weight_grad, ctx.weight_parallel_mode)
weight_grad = all_reduce(weight_grad, ctx.input_parallel_mode)
if ctx.use_bias:
bias_grad, weight_grad = weight_grad[0], weight_grad[1]
dz = output_grad * weight
dvar = dz * mu * (-0.5) * sigma**(-3)
dvar = all_reduce(torch.sum(dvar, dim=-1, keepdim=True), ctx.output_parallel_mode)
dmean = dz * (-1 / sigma) + dvar * -2 * mu / ctx.normalized_shape
dmean = all_reduce(torch.sum(dmean, dim=-1, keepdim=True), ctx.output_parallel_mode)
input_grad = dz / sigma + dvar * 2 * mu / \
ctx.normalized_shape + dmean / ctx.normalized_shape
return input_grad, weight_grad, bias_grad, None, None, None, None, None
def layernorm_3d(input_: Tensor, weight: Tensor, bias: Optional[Tensor], normalized_shape: int, eps: float,
input_parallel_mode: ParallelMode, weight_parallel_mode: ParallelMode,
output_parallel_mode: ParallelMode) -> Tensor:
bias_grad, weight_grad = output_grad, output_grad * mu / sigma
bias_grad = torch.sum(bias_grad, dim=tuple(range(len(bias_grad.shape))[:-1]))
bias_grad, op = all_reduce(bias_grad, ctx.input_x_weight_parallel_mode, async_op=True)
bias_grad = push_async_grad(op, bias_grad, ctx.bias_id)
weight_grad = torch.sum(weight_grad, dim=tuple(range(len(weight_grad.shape))[:-1]))
weight_grad, op = all_reduce(weight_grad, ctx.input_x_weight_parallel_mode, async_op=True)
weight_grad = push_async_grad(op, weight_grad, ctx.weight_id)
dmu, dmean, dvar = norm_backward(output_grad, mu, sigma, weight)
dvar, dmean = all_reduce(torch.stack((dvar, dmean)), ctx.output_parallel_mode)
input_grad = dmu + (dmean + 2 * dvar * mu) / ctx.normalized_shape
return input_grad, weight_grad, bias_grad, None, None, None, None, None, None, None, None
def layernorm_3d(
input_: Tensor,
weight: Tensor,
bias: Tensor,
normalized_shape: int,
eps: float,
output_parallel_mode: ParallelMode,
input_x_weight_parallel_mode: ParallelMode,
) -> Tensor:
r"""3D parallel Layernorm.
Args:
......@@ -257,16 +378,24 @@ def layernorm_3d(input_: Tensor, weight: Tensor, bias: Optional[Tensor], normali
If a single integer is used, it is treated as a singleton list, and this module will
normalize over the last dimension which is expected to be of that specific size.
eps (float): a value added to the denominator for numerical stability
input_parallel_mode (:class:`colossalai.context.parallel_mode.ParallelMode`): input parallel mode.
weight_parallel_mode (:class:`colossalai.context.parallel_mode.ParallelMode`): weight parallel mode.
output_parallel_mode (:class:`colossalai.context.parallel_mode.ParallelMode`): output parallel mode.
input_x_weight_parallel_mode (:class:`colossalai.context.parallel_mode.ParallelMode`): input x weight parallel mode.
Note:
The parallel_mode should be concluded in ``ParallelMode``. More details about ``ParallelMode`` could be found
in `parallel_mode <https://github.com/hpcaitech/ColossalAI/blob/main/colossalai/context/parallel_mode.py>`_
"""
return _Layernorm3D.apply(input_, weight, bias, normalized_shape, eps, input_parallel_mode, weight_parallel_mode,
output_parallel_mode)
return _Layernorm3D.apply(
input_,
weight,
bias,
id(weight),
id(bias),
normalized_shape,
eps,
output_parallel_mode,
input_x_weight_parallel_mode,
)
def split_tensor_3d(tensor: Tensor, dim: int, parallel_mode: ParallelMode) -> Tensor:
......@@ -315,17 +444,12 @@ def split_batch_3d(input_: Tensor,
The parallel_mode should be concluded in ``ParallelMode``. More details about ``ParallelMode`` could be found
in `parallel_mode <https://github.com/hpcaitech/ColossalAI/blob/main/colossalai/context/parallel_mode.py>`_.
"""
dim_size = input_.size(dim)
if input_.size(dim) <= 1:
return input_
weight_parallel_mode = get_parallel_mode_from_env(WEIGHT_GROUP_3D)
input_parallel_mode = get_parallel_mode_from_env(INPUT_GROUP_3D)
weight_world_size = gpc.get_world_size(weight_parallel_mode)
input_world_size = gpc.get_world_size(input_parallel_mode)
assert dim_size % (input_world_size*weight_world_size) == 0, \
f'The batch size ({dim_size}) is not a multiple of square of 3D depth ({input_world_size*weight_world_size}).'
if input_.size(dim) <= 1:
return input_
output = torch.chunk(input_, weight_world_size, dim=dim)[gpc.get_local_rank(weight_parallel_mode)].contiguous()
output = torch.chunk(output, input_world_size, dim=dim)[gpc.get_local_rank(input_parallel_mode)].contiguous()
return output
......@@ -464,47 +588,3 @@ def reduce_by_batch_3d(tensor: Tensor,
in `parallel_mode <https://github.com/hpcaitech/ColossalAI/blob/main/colossalai/context/parallel_mode.py>`_
"""
return _ReduceByBatch3D.apply(tensor, input_parallel_mode, weight_parallel_mode, reduce_mean)
class _BroadcastWeight3D_FromDiagonal(torch.autograd.Function):
r"""broadcast weight from diagonal.
Args:
input_ (:class:`torch.tensor`): input matrix.
input_parallel_mode (:class:`colossalai.context.parallel_mode.ParallelMode`): input parallel mode.
weight_parallel_mode (:class:`colossalai.context.parallel_mode.ParallelMode`): weight parallel mode.
output_parallel_mode (:class:`colossalai.context.parallel_mode.ParallelMode`): output parallel mode.
Note:
The parallel_mode should be concluded in ``ParallelMode``. More details about ``ParallelMode`` could be found
in `parallel_mode <https://github.com/hpcaitech/ColossalAI/blob/main/colossalai/context/parallel_mode.py>`_
"""
@staticmethod
@custom_fwd(cast_inputs=torch.float16)
def forward(ctx, input_: Tensor, input_parallel_mode: ParallelMode, weight_parallel_mode: ParallelMode,
output_parallel_mode: ParallelMode) -> Tensor:
ranks_in_group = gpc.get_ranks_in_group(input_parallel_mode)
src_rank = ranks_in_group[gpc.get_local_rank(output_parallel_mode)]
output = broadcast(input_, src_rank, input_parallel_mode)
ctx.src_rank = src_rank
ctx.input_parallel_mode = input_parallel_mode
ctx.weight_parallel_mode = weight_parallel_mode
ctx.output_parallel_mode = output_parallel_mode
return output
@staticmethod
@custom_bwd
def backward(ctx, output_grad: Tensor) -> Tuple[Tensor, ...]:
input_grad = reduce(output_grad, ctx.src_rank, ctx.input_parallel_mode)
if gpc.get_local_rank(ctx.input_parallel_mode) == gpc.get_local_rank(ctx.output_parallel_mode):
input_grad = all_reduce(input_grad, ctx.weight_parallel_mode)
else:
input_grad = None
return input_grad, None, None, None
def broadcast_weight_3d_from_diagonal(tensor: Tensor, input_parallel_mode: ParallelMode,
weight_parallel_mode: ParallelMode, output_parallel_mode: ParallelMode) -> Tensor:
return _BroadcastWeight3D_FromDiagonal.apply(tensor, input_parallel_mode, weight_parallel_mode,
output_parallel_mode)
from colossalai.constants import INPUT_GROUP_3D, WEIGHT_GROUP_3D, OUTPUT_GROUP_3D
from colossalai.context.parallel_mode import ParallelMode
from collections import OrderedDict
from functools import partial
import torch
from torch import Tensor
from colossalai.constants import INPUT_GROUP_3D, INPUT_X_WEIGHT_3D, OUTPUT_GROUP_3D, OUTPUT_X_WEIGHT_3D, WEIGHT_GROUP_3D
from colossalai.core import global_context as gpc
from colossalai.global_variables import tensor_parallel_env as env
from torch import Tensor
def get_depth_from_env() -> int:
......@@ -17,30 +21,17 @@ def get_depth_from_env() -> int:
def get_parallel_mode_from_env(group):
assert group in [INPUT_GROUP_3D, WEIGHT_GROUP_3D, OUTPUT_GROUP_3D], \
assert group in [INPUT_GROUP_3D, WEIGHT_GROUP_3D, OUTPUT_GROUP_3D, INPUT_X_WEIGHT_3D, OUTPUT_X_WEIGHT_3D], \
f'{group} is not valid for 3D tensor parallelism.'
return getattr(env, group)
def get_last_group(a, b):
mapping = {
ParallelMode.PARALLEL_3D_INPUT: 'A',
ParallelMode.PARALLEL_3D_WEIGHT: 'B',
ParallelMode.PARALLEL_3D_OUTPUT: 'C',
}
res = chr(ord('A') + ord('B') + ord('C') - ord(mapping[a]) - ord(mapping[b]))
if res == 'A':
return ParallelMode.PARALLEL_3D_INPUT
elif res == 'B':
return ParallelMode.PARALLEL_3D_WEIGHT
elif res == 'C':
return ParallelMode.PARALLEL_3D_OUTPUT
def swap_in_out_group():
env.input_group_3d, env.output_group_3d = env.output_group_3d, env.input_group_3d
env.input_x_weight_group_3d, env.output_x_weight_group_3d = (
env.output_x_weight_group_3d,
env.input_x_weight_group_3d,
)
def dbg_check_shape(tensor: Tensor, shape: tuple):
......@@ -49,3 +40,60 @@ def dbg_check_shape(tensor: Tensor, shape: tuple):
print(tensor.shape)
assert tensor.shape == shape, \
'{} does not match {}'.format(tensor.shape, shape)
class AsyncGradientBucket(object):
def __init__(self):
self.bucket = OrderedDict()
def __len__(self):
return len(self.bucket)
def push(self, async_op, grad_tensor, param_id):
self.bucket[param_id] = tuple((async_op, grad_tensor))
return torch.zeros_like(grad_tensor, dtype=grad_tensor.dtype, device=grad_tensor.device)
def pop(self, param_id):
grad = None
if param_id in self.bucket:
op, grad = self.bucket.pop(param_id)
if op is not None:
op.wait()
return grad
def synchronize(self, params):
for p in params:
i = id(p)
if i in self.bucket:
op, grad = self.bucket.pop(i)
if op is not None:
op.wait()
p.grad.add_(grad)
_async_grad_bucket = AsyncGradientBucket()
def push_async_grad(op, grad, param_id):
return _async_grad_bucket.push(op, grad, param_id)
def pop_async_grad(param_id):
return _async_grad_bucket.pop(param_id)
def _async_grad_hook(grad, param_id):
grad.add_(pop_async_grad(param_id))
return grad
def register_async_grad_hook(param):
param.register_hook(partial(_async_grad_hook, param_id=id(param)))
def synchronize(params=list()):
_async_grad_bucket.synchronize(params)
torch.cuda.default_stream().synchronize()
if len(_async_grad_bucket) > 0:
raise RuntimeError(f"{len(_async_grad_bucket)} asynchronous gradient(s) not collected.")
......@@ -5,24 +5,36 @@ from typing import Callable
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor
from torch.nn import Parameter
from colossalai.communication import all_reduce, broadcast
from colossalai.constants import INPUT_GROUP_3D, WEIGHT_GROUP_3D
from colossalai.constants import INPUT_GROUP_3D, INPUT_X_WEIGHT_3D, OUTPUT_GROUP_3D, OUTPUT_X_WEIGHT_3D, WEIGHT_GROUP_3D
from colossalai.context import ParallelMode, seed
from colossalai.core import global_context as gpc
from colossalai.global_variables import tensor_parallel_env as env
from colossalai.nn import init as init
from colossalai.nn.layer.base_layer import ParallelLayer
from colossalai.registry import LAYERS
from colossalai.utils.checkpointing import (broadcast_state_dict, gather_tensor_parallel_state_dict,
partition_tensor_parallel_state_dict)
from colossalai.utils.checkpointing import (
broadcast_state_dict,
gather_tensor_parallel_state_dict,
partition_tensor_parallel_state_dict,
)
from colossalai.utils.cuda import get_current_device
from torch import Tensor
from torch.nn import Parameter
from ..utils import divide, set_tensor_parallel_attribute_by_partition, to_2tuple
from ._operation import (all_gather_tensor_3d, broadcast_weight_3d_from_diagonal, classifier_3d, layernorm_3d,
linear_3d, reduce_scatter_tensor_3d, split_tensor_3d)
from ._utils import get_depth_from_env, get_last_group, get_parallel_mode_from_env, swap_in_out_group
from ._operation import (
all_gather_tensor_3d,
classifier_3d,
layernorm_3d,
linear_3d,
reduce_scatter_tensor_3d,
split_batch_3d,
split_tensor_3d,
vocab_parallel_classifier_3d,
)
from ._utils import get_depth_from_env, get_parallel_mode_from_env, register_async_grad_hook, swap_in_out_group
@LAYERS.register_module
......@@ -45,7 +57,8 @@ class LayerNorm3D(ParallelLayer):
super().__init__()
self.input_parallel_mode = get_parallel_mode_from_env(INPUT_GROUP_3D)
self.weight_parallel_mode = get_parallel_mode_from_env(WEIGHT_GROUP_3D)
self.output_parallel_mode = get_last_group(self.input_parallel_mode, self.weight_parallel_mode)
self.output_parallel_mode = get_parallel_mode_from_env(OUTPUT_GROUP_3D)
self.input_x_weight_parallel_mode = get_parallel_mode_from_env(INPUT_X_WEIGHT_3D)
self.depth = get_depth_from_env()
self.normalized_shape = normalized_shape
self.normalized_shape_per_partition = divide(normalized_shape, self.depth)
......@@ -58,6 +71,7 @@ class LayerNorm3D(ParallelLayer):
else:
self.bias = None
self.variance_epsilon = eps
self.reset_parameters()
self._set_tensor_parallel_attributes()
def _set_tensor_parallel_attributes(self) -> None:
......@@ -67,8 +81,10 @@ class LayerNorm3D(ParallelLayer):
def reset_parameters(self) -> None:
init.ones_()(self.weight)
register_async_grad_hook(self.weight)
if self.bias is not None:
init.zeros_()(self.bias)
register_async_grad_hook(self.bias)
def _load_from_global_state_dict(self, state_dict, prefix, *args, **kwargs):
local_state = OrderedDict()
......@@ -134,8 +150,15 @@ class LayerNorm3D(ParallelLayer):
destination.update(local_state)
def forward(self, input_: Tensor) -> Tensor:
return layernorm_3d(input_, self.weight, self.bias, self.normalized_shape, self.variance_epsilon,
self.input_parallel_mode, self.weight_parallel_mode, self.output_parallel_mode)
return layernorm_3d(
input_,
self.weight,
self.bias,
self.normalized_shape,
self.variance_epsilon,
self.output_parallel_mode,
self.input_x_weight_parallel_mode,
)
@LAYERS.register_module
......@@ -161,6 +184,7 @@ class Linear3D(ParallelLayer):
out_features: int,
bias: bool = True,
dtype: torch.dtype = None,
skip_bias_add: bool = False,
weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)),
bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1)):
super().__init__()
......@@ -168,10 +192,12 @@ class Linear3D(ParallelLayer):
self.out_features = out_features
self.input_parallel_mode = get_parallel_mode_from_env(INPUT_GROUP_3D)
self.weight_parallel_mode = get_parallel_mode_from_env(WEIGHT_GROUP_3D)
self.output_parallel_mode = get_last_group(self.input_parallel_mode, self.weight_parallel_mode)
self.output_parallel_mode = get_parallel_mode_from_env(OUTPUT_GROUP_3D)
self.output_x_weight_parallel_mode = get_parallel_mode_from_env(OUTPUT_X_WEIGHT_3D)
self.depth = get_depth_from_env()
self.in_features_per_partition = divide(in_features, self.depth)
self.out_features_per_partition = divide(out_features, self.depth**2)
self.skip_bias_add = skip_bias_add
self.in_features_per_partition = divide(in_features, self.depth**2)
self.out_features_per_partition = divide(out_features, self.depth)
self.bias_features_per_partition = divide(out_features, self.depth)
self.weight = Parameter(
......@@ -194,18 +220,23 @@ class Linear3D(ParallelLayer):
if self.bias is not None:
set_tensor_parallel_attribute_by_partition(self.bias, self.depth)
def _sync_grad_hook(self, grad) -> Tensor:
grad = all_reduce(grad.clone(), self.output_x_weight_parallel_mode)
return grad
def reset_parameters(self, weight_initializer, bias_initializer) -> None:
with seed(ParallelMode.TENSOR):
fan_in, fan_out = self.in_features, self.out_features
weight_initializer(self.weight, fan_in=fan_in, fan_out=fan_out)
register_async_grad_hook(self.weight)
if self.bias is not None:
bias_initializer(self.bias, fan_in=fan_in)
weight_src_rank = gpc.get_ranks_in_group(self.weight_parallel_mode)[0]
output_src_rank = gpc.get_ranks_in_group(self.output_parallel_mode)[0]
broadcast(self.bias, weight_src_rank, self.weight_parallel_mode)
broadcast(self.bias, output_src_rank, self.output_parallel_mode)
broadcast(self.bias,
gpc.get_ranks_in_group(self.output_x_weight_parallel_mode)[0],
self.output_x_weight_parallel_mode)
self.bias.register_hook(self._sync_grad_hook)
def _load_from_global_state_dict(self, state_dict, prefix, *args, **kwargs):
local_state = OrderedDict()
......@@ -256,7 +287,7 @@ class Linear3D(ParallelLayer):
local_state,
self.weight_parallel_mode,
dims={
weight_key: -1,
weight_key: 0,
bias_key: 0
},
partition_states={
......@@ -279,7 +310,7 @@ class Linear3D(ParallelLayer):
local_state,
self.weight_parallel_mode,
dims={
weight_key: -1,
weight_key: 0,
bias_key: 0
},
partition_states={
......@@ -324,8 +355,20 @@ class Linear3D(ParallelLayer):
destination.update(local_state)
def forward(self, input_: Tensor) -> Tensor:
return linear_3d(input_, self.weight, self.bias, self.input_parallel_mode, self.weight_parallel_mode,
self.output_parallel_mode)
output = linear_3d(
input_,
self.weight,
self.input_parallel_mode,
self.weight_parallel_mode,
self.output_parallel_mode,
)
if not self.skip_bias_add:
if self.bias is not None:
output = output + self.bias
return output
else:
return output, self.bias
@LAYERS.register_module
......@@ -360,7 +403,7 @@ class Classifier3D(ParallelLayer):
self.num_classes = num_classes
self.input_parallel_mode = get_parallel_mode_from_env(INPUT_GROUP_3D)
self.weight_parallel_mode = get_parallel_mode_from_env(WEIGHT_GROUP_3D)
self.output_parallel_mode = get_last_group(self.input_parallel_mode, self.weight_parallel_mode)
self.output_parallel_mode = get_parallel_mode_from_env(OUTPUT_GROUP_3D)
self.depth = get_depth_from_env()
self.in_features_per_partition = divide(in_features, self.depth)
......@@ -386,19 +429,17 @@ class Classifier3D(ParallelLayer):
def reset_parameters(self, weight_initializer, bias_initializer) -> None:
with seed(ParallelMode.TENSOR):
fan_in, fan_out = self.in_features, self.num_classes
weight_src_rank = gpc.get_ranks_in_group(self.weight_parallel_mode)[0]
output_src_rank = gpc.get_ranks_in_group(self.output_parallel_mode)[0]
input_src_rank = gpc.get_ranks_in_group(self.input_parallel_mode)[0]
if self.has_weight:
weight_initializer(self.weight, fan_in=fan_in, fan_out=fan_out)
broadcast(self.weight, weight_src_rank, self.weight_parallel_mode)
broadcast(self.weight, gpc.get_ranks_in_group(self.weight_parallel_mode)[0], self.weight_parallel_mode)
register_async_grad_hook(self.weight)
if self.bias is not None:
bias_initializer(self.bias, fan_in=fan_in)
broadcast(self.bias, weight_src_rank, self.weight_parallel_mode)
broadcast(self.bias, output_src_rank, self.output_parallel_mode)
broadcast(self.bias, input_src_rank, self.input_parallel_mode)
broadcast(self.bias, gpc.get_ranks_in_group(ParallelMode.TENSOR)[0], ParallelMode.TENSOR)
register_async_grad_hook(self.bias)
def _load_from_global_state_dict(self, state_dict, prefix, *args, **kwargs):
local_state = OrderedDict()
......@@ -468,8 +509,14 @@ class Classifier3D(ParallelLayer):
destination.update(local_state)
def forward(self, input_: Tensor) -> Tensor:
return classifier_3d(input_, self.weight, self.bias, self.input_parallel_mode, self.weight_parallel_mode,
self.output_parallel_mode)
return classifier_3d(
input_,
self.weight,
self.bias,
self.input_parallel_mode,
self.weight_parallel_mode,
self.output_parallel_mode,
)
@LAYERS.register_module
......@@ -504,7 +551,8 @@ class VocabParallelClassifier3D(ParallelLayer):
self.num_classes = num_classes
self.input_parallel_mode = get_parallel_mode_from_env(INPUT_GROUP_3D)
self.weight_parallel_mode = get_parallel_mode_from_env(WEIGHT_GROUP_3D)
self.output_parallel_mode = get_last_group(self.input_parallel_mode, self.weight_parallel_mode)
self.output_parallel_mode = get_parallel_mode_from_env(OUTPUT_GROUP_3D)
self.output_x_weight_parallel_mode = get_parallel_mode_from_env(OUTPUT_X_WEIGHT_3D)
self.depth = get_depth_from_env()
self.in_features_per_partition = divide(in_features, self.depth)
self.out_features_per_partition = divide(num_classes, self.depth**2)
......@@ -544,12 +592,14 @@ class VocabParallelClassifier3D(ParallelLayer):
if self.has_weight:
weight_initializer(self.weight, fan_in=fan_in, fan_out=fan_out)
register_async_grad_hook(self.weight)
if self.bias is not None:
bias_initializer(self.bias, fan_in=fan_in)
weight_src_rank = gpc.get_ranks_in_group(self.weight_parallel_mode)[0]
output_src_rank = gpc.get_ranks_in_group(self.output_parallel_mode)[0]
broadcast(self.bias, weight_src_rank, self.weight_parallel_mode)
broadcast(self.bias, output_src_rank, self.output_parallel_mode)
broadcast(self.bias,
gpc.get_ranks_in_group(self.output_x_weight_parallel_mode)[0],
self.output_x_weight_parallel_mode)
register_async_grad_hook(self.bias)
def _load_from_global_state_dict(self, state_dict, prefix, *args, **kwargs):
local_state = OrderedDict()
......@@ -668,8 +718,14 @@ class VocabParallelClassifier3D(ParallelLayer):
destination.update(local_state)
def forward(self, input_: Tensor) -> Tensor:
return linear_3d(input_, self.weight.transpose(0, 1), self.bias, self.input_parallel_mode,
self.weight_parallel_mode, self.output_parallel_mode)
return vocab_parallel_classifier_3d(
input_,
self.weight,
self.bias,
self.input_parallel_mode,
self.weight_parallel_mode,
self.output_parallel_mode,
)
@LAYERS.register_module
......@@ -708,12 +764,16 @@ class PatchEmbedding3D(ParallelLayer):
self.depth = get_depth_from_env()
self.input_parallel_mode = get_parallel_mode_from_env(INPUT_GROUP_3D)
self.weight_parallel_mode = get_parallel_mode_from_env(WEIGHT_GROUP_3D)
self.output_parallel_mode = get_last_group(self.input_parallel_mode, self.weight_parallel_mode)
self.patch_size = to_2tuple(patch_size)
grid_size = to_2tuple(img_size // patch_size)
num_patches = grid_size[0] * grid_size[1]
self.output_parallel_mode = get_parallel_mode_from_env(OUTPUT_GROUP_3D)
self.input_x_weight_parallel_mode = get_parallel_mode_from_env(INPUT_X_WEIGHT_3D)
img_size = to_2tuple(img_size)
patch_size = to_2tuple(patch_size)
self.img_size = img_size
self.patch_size = patch_size
self.grid_size = (img_size[0] // patch_size[0], img_size[1] // patch_size[1])
self.num_patches = self.grid_size[0] * self.grid_size[1]
self.embed_size = embed_size
embed_size_per_partition = divide(embed_size, self.depth)
embed_size_per_partition = embed_size // self.depth
self.flatten = flatten
self.weight = nn.Parameter(
......@@ -725,7 +785,7 @@ class PatchEmbedding3D(ParallelLayer):
self.cls_token = nn.Parameter(
torch.zeros((1, 1, embed_size_per_partition), device=get_current_device(), dtype=dtype))
self.pos_embed = nn.Parameter(
torch.zeros((1, num_patches + 1, embed_size_per_partition), device=get_current_device(), dtype=dtype))
torch.zeros((1, self.num_patches + 1, embed_size_per_partition), device=get_current_device(), dtype=dtype))
self.reset_parameters(weight_initializer, bias_initializer, position_embed_initializer)
self._set_tensor_parallel_attributes()
......@@ -737,8 +797,7 @@ class PatchEmbedding3D(ParallelLayer):
set_tensor_parallel_attribute_by_partition(self.pos_embed, self.depth)
def _sync_grad_hook(self, grad) -> Tensor:
grad = all_reduce(grad.clone(), self.input_parallel_mode)
grad = all_reduce(grad, self.weight_parallel_mode)
grad = all_reduce(grad.clone(), self.input_x_weight_parallel_mode)
return grad
def reset_parameters(self, weight_initializer, bias_initializer, position_embed_initializer) -> None:
......@@ -749,14 +808,10 @@ class PatchEmbedding3D(ParallelLayer):
bias_initializer(self.bias, fan_in=fan_in)
position_embed_initializer(self.pos_embed)
weight_src_rank = gpc.get_ranks_in_group(self.weight_parallel_mode)[0]
input_src_rank = gpc.get_ranks_in_group(self.input_parallel_mode)[0]
broadcast(self.weight, weight_src_rank, self.weight_parallel_mode)
broadcast(self.bias, weight_src_rank, self.weight_parallel_mode)
broadcast(self.pos_embed, weight_src_rank, self.weight_parallel_mode)
broadcast(self.weight, input_src_rank, self.input_parallel_mode)
broadcast(self.bias, input_src_rank, self.input_parallel_mode)
broadcast(self.pos_embed, input_src_rank, self.input_parallel_mode)
src_rank = gpc.get_ranks_in_group(self.input_x_weight_parallel_mode)[0]
broadcast(self.weight, src_rank, self.input_x_weight_parallel_mode)
broadcast(self.bias, src_rank, self.input_x_weight_parallel_mode)
broadcast(self.pos_embed, src_rank, self.input_x_weight_parallel_mode)
self.weight.register_hook(self._sync_grad_hook)
self.bias.register_hook(self._sync_grad_hook)
......@@ -850,8 +905,9 @@ class PatchEmbedding3D(ParallelLayer):
destination.update(local_state)
def forward(self, input_: Tensor) -> Tensor:
input_ = split_tensor_3d(input_, 0, self.weight_parallel_mode)
input_ = split_tensor_3d(input_, 0, self.input_parallel_mode)
input_ = split_batch_3d(input_,
input_parallel_mode=self.input_parallel_mode,
weight_parallel_mode=self.weight_parallel_mode)
output = F.conv2d(input_, self.weight, self.bias, stride=self.patch_size)
if self.flatten:
output = output.flatten(2).transpose(1, 2) # BCHW -> BNC
......@@ -906,7 +962,8 @@ class Embedding3D(ParallelLayer):
self.depth = get_depth_from_env()
self.input_parallel_mode = get_parallel_mode_from_env(INPUT_GROUP_3D)
self.weight_parallel_mode = get_parallel_mode_from_env(WEIGHT_GROUP_3D)
self.output_parallel_mode = get_last_group(self.input_parallel_mode, self.weight_parallel_mode)
self.output_parallel_mode = get_parallel_mode_from_env(OUTPUT_GROUP_3D)
self.input_x_weight_parallel_mode = get_parallel_mode_from_env(INPUT_X_WEIGHT_3D)
self.num_embeddings = num_embeddings
self.embed_dim = embedding_dim
......@@ -924,13 +981,18 @@ class Embedding3D(ParallelLayer):
def _set_tensor_parallel_attributes(self) -> None:
set_tensor_parallel_attribute_by_partition(self.weight, self.depth)
def _sync_grad_hook(self, grad) -> Tensor:
grad = all_reduce(grad.clone(), self.input_x_weight_parallel_mode)
return grad
def reset_parameters(self, weight_initializer) -> None:
with seed(ParallelMode.TENSOR):
fan_in, fan_out = self.num_embeddings, self.embed_dim
weight_initializer(self.weight, fan_in=fan_in, fan_out=fan_out)
self._fill_padding_idx_with_zero()
weight_src_rank = gpc.get_ranks_in_group(self.weight_parallel_mode)[0]
broadcast(self.weight, weight_src_rank, self.weight_parallel_mode)
broadcast(self.weight,
gpc.get_ranks_in_group(self.input_x_weight_parallel_mode)[0], self.input_x_weight_parallel_mode)
self.weight.register_hook(self._sync_grad_hook)
def _fill_padding_idx_with_zero(self) -> None:
if self.padding_idx is not None:
......@@ -981,11 +1043,10 @@ class Embedding3D(ParallelLayer):
destination.update(local_state)
def forward(self, input_: Tensor) -> Tensor:
input_ = split_tensor_3d(input_, 0, self.weight_parallel_mode)
input_ = split_tensor_3d(input_, 0, self.input_parallel_mode)
weight = broadcast_weight_3d_from_diagonal(self.weight, self.input_parallel_mode, self.weight_parallel_mode,
self.output_parallel_mode)
output = F.embedding(input_, weight, self.padding_idx, *self.embed_args, **self.embed_kwargs)
input_ = split_batch_3d(input_,
input_parallel_mode=self.input_parallel_mode,
weight_parallel_mode=self.weight_parallel_mode)
output = F.embedding(input_, self.weight, self.padding_idx, *self.embed_args, **self.embed_kwargs)
return output
......@@ -1039,7 +1100,7 @@ class VocabParallelEmbedding3D(ParallelLayer):
self.depth = get_depth_from_env()
self.input_parallel_mode = get_parallel_mode_from_env(INPUT_GROUP_3D)
self.weight_parallel_mode = get_parallel_mode_from_env(WEIGHT_GROUP_3D)
self.output_parallel_mode = get_last_group(self.input_parallel_mode, self.weight_parallel_mode)
self.output_parallel_mode = get_parallel_mode_from_env(OUTPUT_GROUP_3D)
self.num_embeddings_per_partition = divide(self.num_embeddings, self.depth**2)
self.embed_dim_per_partition = divide(self.embed_dim, self.depth)
vocab_parallel_rank = gpc.get_local_rank(self.input_parallel_mode)
......
from .layers import (DropPath, VanillaClassifier, VanillaLayerNorm, VanillaPatchEmbedding, WrappedDropout,
WrappedDropPath)
from .layers import (
DropPath,
VanillaClassifier,
VanillaLayerNorm,
VanillaLinear,
VanillaPatchEmbedding,
WrappedDropout,
WrappedDropPath,
)
__all__ = [
"VanillaLayerNorm", "VanillaPatchEmbedding", "VanillaClassifier", "DropPath", "WrappedDropout", "WrappedDropPath"
"VanillaLayerNorm", "VanillaPatchEmbedding", "VanillaClassifier", "DropPath", "WrappedDropout", "WrappedDropPath",
"VanillaLinear"
]
import math
from typing import Callable
import torch
import torch.nn.functional as F
from colossalai.context import seed
from colossalai.nn import init as init
from colossalai.registry import LAYERS
from colossalai.utils.cuda import get_current_device
from torch import Tensor
from torch import nn as nn
from ..utils import to_2tuple
def drop_path(x, drop_prob: float = 0., training: bool = False):
"""Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
This is the same as the DropConnect impl I created for EfficientNet, etc networks, however,
the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper...
See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for
changing the layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use
'survival rate' as the argument.
Args:
drop_prob (float, optional): probability of dropping path, defaults 0.0.
training (bool, optional): whether in training progress, defaults False.
"""
if drop_prob == 0. or not training:
return x
keep_prob = 1 - drop_prob
shape = (x.shape[0],) + (1,) * (x.ndim - 1) # work with diff dim tensors, not just 2D ConvNets
random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device)
random_tensor.floor_() # binarize
output = x.div(keep_prob) * random_tensor
return output
class DropPath(nn.Module):
"""
Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
Adapted from https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/layers/drop.py
Args:
drop_prob (float, optional): probability of dropping path, defaults None.
"""
def __init__(self, drop_prob=None):
super(DropPath, self).__init__()
self.drop_prob = drop_prob
def forward(self, x):
return drop_path(x, self.drop_prob, self.training)
class WrappedDropout(nn.Module):
r"""Same as torch.nn.Dropout. But it is wrapped with the context of seed manager. During training, randomly zeroes
some elements of the input tensor with probability p using samples from a Bernoulli distribution. Each
channel will be zeroed out independently on every forward call. Furthermore, the outputs are scaled by a factor of
1/(1-p) during training. This means that during evaluation the module simply computes an identity function.
Args:
p (float, optional): probability of an element to be zeroed, defaults 0.5.
inplace (bool, optional): whether to do dropout in-place, default to be False.
mode (:class:`colossalai.context.ParallelMode`): The chosen parallel mode.
Note:
The parallel_mode should be concluded in ``ParallelMode``. More details about ``ParallelMode`` could be found
in `parallel_mode <https://github.com/hpcaitech/ColossalAI/blob/main/colossalai/context/parallel_mode.py>`_
"""
def __init__(self, p: float = 0.5, inplace: bool = False, mode=None):
super().__init__()
if p < 0 or p > 1:
raise ValueError("dropout probability has to be between 0 and 1, "
"but got {}".format(p))
self.p = p
self.inplace = inplace
if mode is None:
self.func = self.nonefunc
else:
self.func = self.normalfunc
self.mode = mode
def nonefunc(self, inputs):
return F.dropout(inputs, self.p, self.training, self.inplace)
def normalfunc(self, inputs):
with seed(self.mode):
return F.dropout(inputs, self.p, self.training, self.inplace)
def forward(self, inputs):
return self.func(inputs)
class WrappedDropPath(nn.Module):
r"""Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
Here, it is wrapped with the context of seed manager.
Args:
p (float, optional): probability of dropping path, defaults 0.0.
mode (:class:`colossalai.context.ParallelMode`): The chosen parallel mode.
Note:
The parallel_mode should be concluded in ``ParallelMode``. More details about ``ParallelMode`` could be found
in `parallel_mode <https://github.com/hpcaitech/ColossalAI/blob/main/colossalai/context/parallel_mode.py>`_
"""
def __init__(self, p: float = 0., mode=None):
super().__init__()
self.p = p
self.mode = mode
if self.mode is None:
self.func = self.nonefunc
else:
self.func = self.normalfunc
self.mode = mode
def nonefunc(self, inputs):
return drop_path(inputs, self.p, self.training)
def normalfunc(self, inputs):
with seed(self.mode):
return drop_path(inputs, self.p, self.training)
def forward(self, inputs):
return self.func(inputs)
@LAYERS.register_module
class VanillaPatchEmbedding(nn.Module):
r"""
2D Image to Patch Embedding
Args:
img_size (int): image size.
patch_size (int): patch size.
in_chans (int): number of channels of input image.
embed_size (int): size of embedding.
dtype (:class:`torch.dtype`, optional): The dtype of parameters, defaults to None.
flatten (bool, optional): whether to flatten output tensor, defaults to True.
weight_initializer (:class:`typing.Callable`, optional):
The initializer of weight, defaults to kaiming uniform initializer.
bias_initializer (:class:`typing.Callable`, optional):
The initializer of bias, defaults to xavier uniform initializer.
position_embed_initializer (:class:`typing.Callable`, optional):
The initializer of position embedding, defaults to zeros initializer.
More details about initializer please refer to
`init <https://github.com/hpcaitech/ColossalAI/blob/main/colossalai/nn/init.py>`_.
"""
def __init__(self,
img_size: int,
patch_size: int,
in_chans: int,
embed_size: int,
flatten: bool = True,
dtype: torch.dtype = None,
weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)),
bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1),
position_embed_initializer: Callable = init.zeros_()):
super().__init__()
img_size = to_2tuple(img_size)
patch_size = to_2tuple(patch_size)
self.img_size = img_size
self.patch_size = patch_size
self.grid_size = (img_size[0] // patch_size[0], img_size[1] // patch_size[1])
self.num_patches = self.grid_size[0] * self.grid_size[1]
self.flatten = flatten
self.weight = nn.Parameter(
torch.empty((embed_size, in_chans, *self.patch_size), device=get_current_device(), dtype=dtype))
self.bias = nn.Parameter(torch.empty(embed_size, device=get_current_device(), dtype=dtype))
self.cls_token = nn.Parameter(torch.zeros((1, 1, embed_size), device=get_current_device(), dtype=dtype))
self.pos_embed = nn.Parameter(
torch.zeros((1, self.num_patches + 1, embed_size), device=get_current_device(), dtype=dtype))
self.reset_parameters(weight_initializer, bias_initializer, position_embed_initializer)
def reset_parameters(self, weight_initializer, bias_initializer, position_embed_initializer):
fan_in, fan_out = nn.init._calculate_fan_in_and_fan_out(self.weight)
weight_initializer(self.weight, fan_in=fan_in, fan_out=fan_out)
bias_initializer(self.bias, fan_in=fan_in)
position_embed_initializer(self.pos_embed)
def forward(self, input_: Tensor) -> Tensor:
B, C, H, W = input_.shape
assert H == self.img_size[0] and W == self.img_size[1], \
f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
output = F.conv2d(input_, self.weight, self.bias, stride=self.patch_size)
if self.flatten:
output = output.flatten(2).transpose(1, 2) # BCHW -> BNC
cls_token = self.cls_token.expand(output.shape[0], -1, -1)
output = torch.cat((cls_token, output), dim=1)
output = output + self.pos_embed
return output
@LAYERS.register_module
class VanillaClassifier(nn.Module):
r"""Dense linear classifier.
Args:
in_features (int): size of each input sample.
num_classes (int): number of classes.
weight (:class:`torch.nn.Parameter`, optional): weight of the classifier, defaults to None.
dtype (:class:`torch.dtype`, optional): The dtype of parameters, defaults to None.
flatten (bool, optional): whether to flatten output tensor, defaults to True.
weight_initializer (:class:`typing.Callable`, optional):
The initializer of weight, defaults to kaiming uniform initializer.
bias_initializer (:class:`typing.Callable`, optional):
The initializer of bias, defaults to xavier uniform initializer.
More details about initializer please refer to
`init <https://github.com/hpcaitech/ColossalAI/blob/main/colossalai/nn/init.py>`_.
"""
def __init__(self,
in_features: int,
num_classes: int,
weight: nn.Parameter = None,
bias: bool = True,
dtype: torch.dtype = None,
weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)),
bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1)):
super().__init__()
self.in_features = in_features
self.num_classes = num_classes
if weight is not None:
self.weight = weight
self.has_weight = False
else:
self.weight = nn.Parameter(
torch.empty(self.num_classes, self.in_features, device=get_current_device(), dtype=dtype))
self.has_weight = True
if bias:
self.bias = nn.Parameter(torch.zeros(self.num_classes, device=get_current_device(), dtype=dtype))
else:
self.bias = None
self.reset_parameters(weight_initializer, bias_initializer)
def reset_parameters(self, weight_initializer, bias_initializer):
fan_in, fan_out = self.in_features, self.num_classes
if self.has_weight:
weight_initializer(self.weight, fan_in=fan_in, fan_out=fan_out)
if self.bias is not None:
bias_initializer(self.bias, fan_in=fan_in)
def forward(self, input_: Tensor) -> Tensor:
return F.linear(input_, self.weight, self.bias)
@LAYERS.register_module
class VanillaLayerNorm(nn.Module):
r"""
Layer Normalization for colossalai
Args:
normalized_shape (int): input shape from an expected input of size.
:math:`[* \times \text{normalized_shape}[0] \times \text{normalized_shape}[1]
\times \ldots \times \text{normalized_shape}[-1]]`
If a single integer is used, it is treated as a singleton list, and this module will
normalize over the last dimension which is expected to be of that specific size.
eps (float): a value added to the denominator for numerical stability, defaults to 1e-05.
bias (bool, optional): Whether to add a bias, defaults to ``True``.
dtype (:class:`torch.dtype`, optional): The dtype of parameters, defaults to None.
"""
def __init__(self, normalized_shape: int, eps=1e-05, bias=True, dtype=None):
super().__init__()
self.normalized_shape = (normalized_shape,)
self.variance_epsilon = eps
factory_kwargs = {'device': get_current_device(), 'dtype': dtype}
self.weight = nn.Parameter(torch.ones(normalized_shape, **factory_kwargs))
if bias:
self.bias = nn.Parameter(torch.zeros(normalized_shape, **factory_kwargs))
else:
self.bias = None
def forward(self, x: Tensor) -> Tensor:
return F.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.variance_epsilon)
import math
from typing import Callable
import torch
import torch.nn.functional as F
from torch import Tensor
from torch import nn as nn
from torch.nn.parameter import Parameter
from colossalai.context import seed
from colossalai.nn import init as init
from colossalai.registry import LAYERS
from colossalai.utils.cuda import get_current_device
from ..utils import to_2tuple
def drop_path(x, drop_prob: float = 0., training: bool = False):
"""Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
This is the same as the DropConnect impl I created for EfficientNet, etc networks, however,
the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper...
See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for
changing the layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use
'survival rate' as the argument.
Args:
drop_prob (float, optional): probability of dropping path, defaults 0.0.
training (bool, optional): whether in training progress, defaults False.
"""
if drop_prob == 0. or not training:
return x
keep_prob = 1 - drop_prob
shape = (x.shape[0],) + (1,) * (x.ndim - 1) # work with diff dim tensors, not just 2D ConvNets
random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device)
random_tensor.floor_() # binarize
output = x.div(keep_prob) * random_tensor
return output
class DropPath(nn.Module):
"""
Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
Adapted from https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/layers/drop.py
Args:
drop_prob (float, optional): probability of dropping path, defaults None.
"""
def __init__(self, drop_prob=None):
super(DropPath, self).__init__()
self.drop_prob = drop_prob
def forward(self, x):
return drop_path(x, self.drop_prob, self.training)
class WrappedDropout(nn.Module):
r"""Same as torch.nn.Dropout. But it is wrapped with the context of seed manager. During training, randomly zeroes
some elements of the input tensor with probability p using samples from a Bernoulli distribution. Each
channel will be zeroed out independently on every forward call. Furthermore, the outputs are scaled by a factor of
1/(1-p) during training. This means that during evaluation the module simply computes an identity function.
Args:
p (float, optional): probability of an element to be zeroed, defaults 0.5.
inplace (bool, optional): whether to do dropout in-place, default to be False.
mode (:class:`colossalai.context.ParallelMode`): The chosen parallel mode.
Note:
The parallel_mode should be concluded in ``ParallelMode``. More details about ``ParallelMode`` could be found
in `parallel_mode <https://github.com/hpcaitech/ColossalAI/blob/main/colossalai/context/parallel_mode.py>`_
"""
def __init__(self, p: float = 0.5, inplace: bool = False, mode=None):
super().__init__()
if p < 0 or p > 1:
raise ValueError("dropout probability has to be between 0 and 1, "
"but got {}".format(p))
self.p = p
self.inplace = inplace
if mode is None:
self.func = self.nonefunc
else:
self.func = self.normalfunc
self.mode = mode
def nonefunc(self, inputs):
return F.dropout(inputs, self.p, self.training, self.inplace)
def normalfunc(self, inputs):
with seed(self.mode):
return F.dropout(inputs, self.p, self.training, self.inplace)
def forward(self, inputs):
return self.func(inputs)
class WrappedDropPath(nn.Module):
r"""Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
Here, it is wrapped with the context of seed manager.
Args:
p (float, optional): probability of dropping path, defaults 0.0.
mode (:class:`colossalai.context.ParallelMode`): The chosen parallel mode.
Note:
The parallel_mode should be concluded in ``ParallelMode``. More details about ``ParallelMode`` could be found
in `parallel_mode <https://github.com/hpcaitech/ColossalAI/blob/main/colossalai/context/parallel_mode.py>`_
"""
def __init__(self, p: float = 0., mode=None):
super().__init__()
self.p = p
self.mode = mode
if self.mode is None:
self.func = self.nonefunc
else:
self.func = self.normalfunc
self.mode = mode
def nonefunc(self, inputs):
return drop_path(inputs, self.p, self.training)
def normalfunc(self, inputs):
with seed(self.mode):
return drop_path(inputs, self.p, self.training)
def forward(self, inputs):
return self.func(inputs)
@LAYERS.register_module
class VanillaPatchEmbedding(nn.Module):
r"""
2D Image to Patch Embedding
Args:
img_size (int): image size.
patch_size (int): patch size.
in_chans (int): number of channels of input image.
embed_size (int): size of embedding.
dtype (:class:`torch.dtype`, optional): The dtype of parameters, defaults to None.
flatten (bool, optional): whether to flatten output tensor, defaults to True.
weight_initializer (:class:`typing.Callable`, optional):
The initializer of weight, defaults to kaiming uniform initializer.
bias_initializer (:class:`typing.Callable`, optional):
The initializer of bias, defaults to xavier uniform initializer.
position_embed_initializer (:class:`typing.Callable`, optional):
The initializer of position embedding, defaults to zeros initializer.
More details about initializer please refer to
`init <https://github.com/hpcaitech/ColossalAI/blob/main/colossalai/nn/init.py>`_.
"""
def __init__(self,
img_size: int,
patch_size: int,
in_chans: int,
embed_size: int,
flatten: bool = True,
dtype: torch.dtype = None,
weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)),
bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1),
position_embed_initializer: Callable = init.zeros_()):
super().__init__()
img_size = to_2tuple(img_size)
patch_size = to_2tuple(patch_size)
self.img_size = img_size
self.patch_size = patch_size
self.grid_size = (img_size[0] // patch_size[0], img_size[1] // patch_size[1])
self.num_patches = self.grid_size[0] * self.grid_size[1]
self.flatten = flatten
self.weight = nn.Parameter(
torch.empty((embed_size, in_chans, *self.patch_size), device=get_current_device(), dtype=dtype))
self.bias = nn.Parameter(torch.empty(embed_size, device=get_current_device(), dtype=dtype))
self.cls_token = nn.Parameter(torch.zeros((1, 1, embed_size), device=get_current_device(), dtype=dtype))
self.pos_embed = nn.Parameter(
torch.zeros((1, self.num_patches + 1, embed_size), device=get_current_device(), dtype=dtype))
self.reset_parameters(weight_initializer, bias_initializer, position_embed_initializer)
def reset_parameters(self, weight_initializer, bias_initializer, position_embed_initializer):
fan_in, fan_out = nn.init._calculate_fan_in_and_fan_out(self.weight)
weight_initializer(self.weight, fan_in=fan_in, fan_out=fan_out)
bias_initializer(self.bias, fan_in=fan_in)
position_embed_initializer(self.pos_embed)
def forward(self, input_: Tensor) -> Tensor:
B, C, H, W = input_.shape
assert H == self.img_size[0] and W == self.img_size[1], \
f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
output = F.conv2d(input_, self.weight, self.bias, stride=self.patch_size)
if self.flatten:
output = output.flatten(2).transpose(1, 2) # BCHW -> BNC
cls_token = self.cls_token.expand(output.shape[0], -1, -1)
output = torch.cat((cls_token, output), dim=1)
output = output + self.pos_embed
return output
@LAYERS.register_module
class VanillaClassifier(nn.Module):
r"""Dense linear classifier.
Args:
in_features (int): size of each input sample.
num_classes (int): number of classes.
weight (:class:`torch.nn.Parameter`, optional): weight of the classifier, defaults to None.
dtype (:class:`torch.dtype`, optional): The dtype of parameters, defaults to None.
flatten (bool, optional): whether to flatten output tensor, defaults to True.
weight_initializer (:class:`typing.Callable`, optional):
The initializer of weight, defaults to kaiming uniform initializer.
bias_initializer (:class:`typing.Callable`, optional):
The initializer of bias, defaults to xavier uniform initializer.
More details about initializer please refer to
`init <https://github.com/hpcaitech/ColossalAI/blob/main/colossalai/nn/init.py>`_.
"""
def __init__(self,
in_features: int,
num_classes: int,
weight: nn.Parameter = None,
bias: bool = True,
dtype: torch.dtype = None,
weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)),
bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1)):
super().__init__()
self.in_features = in_features
self.num_classes = num_classes
if weight is not None:
self.weight = weight
self.has_weight = False
else:
self.weight = nn.Parameter(
torch.empty(self.num_classes, self.in_features, device=get_current_device(), dtype=dtype))
self.has_weight = True
if bias:
self.bias = nn.Parameter(torch.zeros(self.num_classes, device=get_current_device(), dtype=dtype))
else:
self.bias = None
self.reset_parameters(weight_initializer, bias_initializer)
def reset_parameters(self, weight_initializer, bias_initializer):
fan_in, fan_out = self.in_features, self.num_classes
if self.has_weight:
weight_initializer(self.weight, fan_in=fan_in, fan_out=fan_out)
if self.bias is not None:
bias_initializer(self.bias, fan_in=fan_in)
def forward(self, input_: Tensor) -> Tensor:
return F.linear(input_, self.weight, self.bias)
@LAYERS.register_module
class VanillaLayerNorm(nn.Module):
r"""
Layer Normalization for colossalai
Args:
normalized_shape (int): input shape from an expected input of size.
:math:`[* \times \text{normalized_shape}[0] \times \text{normalized_shape}[1]
\times \ldots \times \text{normalized_shape}[-1]]`
If a single integer is used, it is treated as a singleton list, and this module will
normalize over the last dimension which is expected to be of that specific size.
eps (float): a value added to the denominator for numerical stability, defaults to 1e-05.
bias (bool, optional): Whether to add a bias, defaults to ``True``.
dtype (:class:`torch.dtype`, optional): The dtype of parameters, defaults to None.
"""
def __init__(self, normalized_shape: int, eps=1e-05, bias=True, dtype=None):
super().__init__()
self.normalized_shape = (normalized_shape,)
self.variance_epsilon = eps
factory_kwargs = {'device': get_current_device(), 'dtype': dtype}
self.weight = nn.Parameter(torch.ones(normalized_shape, **factory_kwargs))
if bias:
self.bias = nn.Parameter(torch.zeros(normalized_shape, **factory_kwargs))
else:
self.bias = None
def forward(self, x: Tensor) -> Tensor:
return F.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.variance_epsilon)
@LAYERS.register_module
class VanillaLinear(nn.Module):
"""Linear layer.
Args:
in_features (int): size of each input sample.
out_features (int): size of each output sample.
bias (bool, optional): If set to ``False``, the layer will not learn an additive bias, defaults to ``True``.
dtype (:class:`torch.dtype`, optional): The dtype of parameters, defaults to None.
skip_bias_add: bool (optional, default to be false).
weight_initializer (:class:`typing.Callable`, optional):
The initializer of weight, defaults to kaiming uniform initializer.
bias_initializer (:class:`typing.Callable`, optional):
The initializer of bias, defaults to xavier uniform initializer.
More details about ``initializer`` please refer to
`init <https://github.com/hpcaitech/ColossalAI/blob/main/colossalai/nn/init.py>`_.
"""
def __init__(self,
in_features: int,
out_features: int,
bias: bool = True,
dtype: torch.dtype = None,
skip_bias_add: bool = False,
weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)),
bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1),
**kwargs) -> None:
super().__init__()
self.in_features = in_features
self.out_features = out_features
self.skip_bias_add = skip_bias_add
factory_kwargs = {'device': get_current_device(), 'dtype': dtype}
self.weight = Parameter(torch.empty(self.out_features, self.in_features, **factory_kwargs))
if bias:
self.bias = Parameter(torch.empty(self.out_features, **factory_kwargs))
else:
self.bias = None
weight_initializer(self.weight, fan_in=in_features, fan_out=out_features)
if self.bias is not None:
bias_initializer(self.bias, fan_in=in_features)
def forward(self, input: Tensor) -> Tensor:
if not self.skip_bias_add:
return F.linear(input, self.weight, self.bias)
else:
return F.linear(input, self.weight), self.bias
import math
from typing import Optional
import torch
from colossalai.kernel.op_builder import CPUAdamBuilder
from colossalai.registry import OPTIMIZERS
from .nvme_optimizer import NVMeOptimizer
from typing import Optional
@OPTIMIZERS.register_module
......@@ -11,12 +14,12 @@ class CPUAdam(NVMeOptimizer):
"""Implements Adam algorithm.
Supports parameters updating on both GPU and CPU, depanding on the device of paramters.
But the parameters and gradients should on the same device:
But the parameters and gradients should on the same device:
* Parameters on CPU and gradients on CPU is allowed.
* Parameters on GPU and gradients on GPU is allowed.
* Parameters on GPU and gradients on CPU is **not** allowed.
Requires ColossalAI to be installed via ``pip install .``.
`CPUAdam` requires CUDA extensions which can be built during installation or runtime.
This version of CPU Adam accelates parameters updating on CPU with SIMD.
Support of AVX2 or AVX512 is required.
......@@ -44,7 +47,7 @@ class CPUAdam(NVMeOptimizer):
(default: False) NOT SUPPORTED yet in CPUAdam!
adamw_mode (boolean, optional): Apply L2 regularization or weight decay
True for decoupled weight decay(also known as AdamW) (default: True)
simd_log (boolean, optional): whether to show if you are using SIMD to
simd_log (boolean, optional): whether to show if you are using SIMD to
accelerate. (default: False)
nvme_offload_fraction (float, optional): Fraction of optimizer states to be offloaded to NVMe. Defaults to 0.0.
nvme_offload_dir (Optional[str], optional): Directory to save NVMe offload files.
......@@ -74,10 +77,7 @@ class CPUAdam(NVMeOptimizer):
default_args = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay, bias_correction=bias_correction)
super(CPUAdam, self).__init__(model_params, default_args, nvme_offload_fraction, nvme_offload_dir)
self.adamw_mode = adamw_mode
try:
import cpu_adam
except ImportError:
raise ImportError('Please install colossalai from source code to use CPUAdam')
cpu_adam = CPUAdamBuilder().load()
self.cpu_adam_op = cpu_adam.CPUAdamOptimizer(lr, betas[0], betas[1], eps, weight_decay, adamw_mode)
def torch_adam_update(self,
......@@ -114,7 +114,7 @@ class CPUAdam(NVMeOptimizer):
data.addcdiv_(exp_avg, denom, value=-step_size)
@torch.no_grad()
def step(self, closure=None):
def step(self, closure=None, div_scale: float = -1):
loss = None
if closure is not None:
with torch.enable_grad():
......@@ -149,9 +149,10 @@ class CPUAdam(NVMeOptimizer):
self._pre_update(p, 'exp_avg', 'exp_avg_sq')
self.cpu_adam_op.step(state['step'], group['lr'], beta1, beta2, group['eps'], group['weight_decay'],
group['bias_correction'], p.data, p.grad.data, state['exp_avg'],
state['exp_avg_sq'], -1)
state['exp_avg_sq'], div_scale)
self._post_update(p, 'exp_avg', 'exp_avg_sq')
elif target_device.type == 'cuda':
assert div_scale == -1, "div_scale should remain default"
assert state['exp_avg'].device.type == 'cuda', "exp_avg should stay on cuda"
assert state['exp_avg_sq'].device.type == 'cuda', "exp_avg should stay on cuda"
......
......@@ -9,8 +9,7 @@ from colossalai.utils import multi_tensor_applier
class FusedAdam(torch.optim.Optimizer):
"""Implements Adam algorithm.
Currently GPU-only. Requires ColossalAI to be installed via
``pip install .``.
`FusedAdam` requires CUDA extensions which can be built during installation or runtime.
This version of fused Adam implements 2 fusions.
......@@ -20,7 +19,7 @@ class FusedAdam(torch.optim.Optimizer):
:class:`colossalai.nn.optimizer.FusedAdam` may be used as a drop-in replacement for ``torch.optim.AdamW``,
or ``torch.optim.Adam`` with ``adamw_mode=False``
:class:`colossalai.nn.optimizer.FusedAdam` may be used with or without Amp.
:class:`colossalai.nn.optimizer.FusedAdam` may be used with or without Amp.
Adam was been proposed in `Adam: A Method for Stochastic Optimization`_.
......@@ -65,10 +64,12 @@ class FusedAdam(torch.optim.Optimizer):
self.adamw_mode = 1 if adamw_mode else 0
self.set_grad_none = set_grad_none
if multi_tensor_applier.available:
import colossal_C
from colossalai.kernel.op_builder import FusedOptimBuilder
fused_optim = FusedOptimBuilder().load()
# Skip buffer
self._dummy_overflow_buf = torch.cuda.IntTensor([0])
self.multi_tensor_adam = colossal_C.multi_tensor_adam
self.multi_tensor_adam = fused_optim.multi_tensor_adam
else:
raise RuntimeError('FusedAdam requires cuda extensions')
......@@ -80,7 +81,7 @@ class FusedAdam(torch.optim.Optimizer):
else:
super(FusedAdam, self).zero_grad()
def step(self, closure=None, grads=None, output_params=None, scale=None, grad_norms=None):
def step(self, closure=None, grads=None, output_params=None, scale=None, grad_norms=None, div_scale: float = -1):
"""Performs a single optimization step.
Arguments:
......@@ -136,6 +137,6 @@ class FusedAdam(torch.optim.Optimizer):
multi_tensor_applier(self.multi_tensor_adam, self._dummy_overflow_buf, [g_l, p_l, m_l, v_l], group['lr'],
beta1, beta2, group['eps'], group['step'], self.adamw_mode, bias_correction,
group['weight_decay'])
group['weight_decay'], div_scale)
return loss
......@@ -9,8 +9,7 @@ from colossalai.utils import multi_tensor_applier
class FusedLAMB(torch.optim.Optimizer):
"""Implements LAMB algorithm.
Currently GPU-only. Requires ColossalAI to be installed via
``pip install .``.
`FusedLAMB` requires CUDA extensions which can be built during installation or runtime.
This version of fused LAMB implements 2 fusions.
......@@ -76,13 +75,15 @@ class FusedLAMB(torch.optim.Optimizer):
max_grad_norm=max_grad_norm)
super(FusedLAMB, self).__init__(params, defaults)
if multi_tensor_applier.available:
import colossal_C
self.multi_tensor_l2norm = colossal_C.multi_tensor_l2norm
from colossalai.kernel.op_builder import FusedOptimBuilder
fused_optim = FusedOptimBuilder().load()
self.multi_tensor_l2norm = fused_optim.multi_tensor_l2norm
# Skip buffer
self._dummy_overflow_buf = torch.tensor([0],
dtype=torch.int,
device=self.param_groups[0]["params"][0].device)
self.multi_tensor_lamb = colossal_C.multi_tensor_lamb
self.multi_tensor_lamb = fused_optim.multi_tensor_lamb
else:
raise RuntimeError('FusedLAMB requires cuda extensions')
......
......@@ -10,8 +10,7 @@ from colossalai.utils import multi_tensor_applier
class FusedSGD(Optimizer):
r"""Implements stochastic gradient descent (optionally with momentum).
Currently GPU-only. Requires ColossalAI to be installed via
``pip install .``.
`FusedSGD` requires CUDA extensions which can be built during installation or runtime.
This version of fused SGD implements 2 fusions.
......@@ -20,7 +19,7 @@ class FusedSGD(Optimizer):
:class:`colossalai.nn.optimizer.FusedSGD` may be used as a drop-in replacement for ``torch.optim.SGD``
:class:`colossalai.nn.optimizer.FusedSGD` may be used with or without Amp.
:class:`colossalai.nn.optimizer.FusedSGD` may be used with or without Amp.
Nesterov momentum is based on the formula from
`On the importance of initialization and momentum in deep learning`__.
......@@ -80,12 +79,14 @@ class FusedSGD(Optimizer):
self.wd_after_momentum = wd_after_momentum
if multi_tensor_applier.available:
import colossal_C
from colossalai.kernel.op_builder import FusedOptimBuilder
fused_optim = FusedOptimBuilder().load()
# Skip buffer
self._dummy_overflow_buf = torch.tensor([0],
dtype=torch.int,
device=self.param_groups[0]["params"][0].device)
self.multi_tensor_sgd = colossal_C.multi_tensor_sgd
self.multi_tensor_sgd = fused_optim.multi_tensor_sgd
else:
raise RuntimeError('FusedSGD requires cuda extensions')
......
from typing import Any
import torch
from colossalai.nn.optimizer import HybridAdam
from colossalai.nn.optimizer.zero_optimizer import ZeroOptimizer
__all__ = ['GeminiAdamOptimizer']
class GeminiAdamOptimizer(ZeroOptimizer):
def __init__(self, model: torch.nn.Module, **defaults: Any) -> None:
optimizer = HybridAdam(model.parameters(), **defaults)
super().__init__(optimizer, model, **defaults)
from typing import Any, Optional
import torch
from colossalai.utils import multi_tensor_applier
from colossalai.kernel.op_builder import CPUAdamBuilder, FusedOptimBuilder
from colossalai.registry import OPTIMIZERS
from typing import Optional
from colossalai.utils import multi_tensor_applier
from .nvme_optimizer import NVMeOptimizer
......@@ -11,12 +14,12 @@ class HybridAdam(NVMeOptimizer):
"""Implements Adam algorithm.
Supports parameters updating on both GPU and CPU, depanding on the device of paramters.
But the parameters and gradients should on the same device:
But the parameters and gradients should on the same device:
* Parameters on CPU and gradients on CPU is allowed.
* Parameters on GPU and gradients on GPU is allowed.
* Parameters on GPU and gradients on CPU is **not** allowed.
Requires ColossalAI to be installed via ``pip install .``
`HybriadAdam` requires CUDA extensions which can be built during installation or runtime.
This version of Hybrid Adam is an hybrid of CPUAdam and FusedAdam.
......@@ -43,7 +46,7 @@ class HybridAdam(NVMeOptimizer):
(default: False) NOT SUPPORTED yet in CPUAdam!
adamw_mode (boolean, optional): Apply L2 regularization or weight decay
True for decoupled weight decay(also known as AdamW) (default: True)
simd_log (boolean, optional): whether to show if you are using SIMD to
simd_log (boolean, optional): whether to show if you are using SIMD to
accelerate. (default: False)
nvme_offload_fraction (float, optional): Fraction of optimizer states to be offloaded to NVMe. Defaults to 0.0.
nvme_offload_dir (Optional[str], optional): Directory to save NVMe offload files.
......@@ -68,24 +71,23 @@ class HybridAdam(NVMeOptimizer):
weight_decay=0,
adamw_mode=True,
nvme_offload_fraction: float = 0.0,
nvme_offload_dir: Optional[str] = None):
nvme_offload_dir: Optional[str] = None,
**defaults: Any):
default_args = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay, bias_correction=bias_correction)
super(HybridAdam, self).__init__(model_params, default_args, nvme_offload_fraction, nvme_offload_dir)
self.adamw_mode = adamw_mode
try:
import cpu_adam
import colossal_C
except ImportError:
raise ImportError('Please install colossalai from source code to use HybridAdam')
self.cpu_adam_op = cpu_adam.CPUAdamOptimizer(lr, betas[0], betas[1], eps, weight_decay, adamw_mode)
# build during runtime if not found
cpu_optim = CPUAdamBuilder().load()
fused_optim = FusedOptimBuilder().load()
self.cpu_adam_op = cpu_optim.CPUAdamOptimizer(lr, betas[0], betas[1], eps, weight_decay, adamw_mode)
self.gpu_adam_op = colossal_C.multi_tensor_adam
self.gpu_adam_op = fused_optim.multi_tensor_adam
self._dummy_overflow_buf = torch.cuda.IntTensor([0])
@torch.no_grad()
def step(self, closure=None):
def step(self, closure=None, div_scale: float = -1):
loss = None
if closure is not None:
with torch.enable_grad():
......@@ -122,7 +124,7 @@ class HybridAdam(NVMeOptimizer):
self._pre_update(p, 'exp_avg', 'exp_avg_sq')
self.cpu_adam_op.step(state['step'], group['lr'], beta1, beta2, group['eps'], group['weight_decay'],
group['bias_correction'], p.data, p.grad.data, state['exp_avg'],
state['exp_avg_sq'], -1)
state['exp_avg_sq'], div_scale)
self._post_update(p, 'exp_avg', 'exp_avg_sq')
elif target_device.type == 'cuda':
......@@ -142,6 +144,6 @@ class HybridAdam(NVMeOptimizer):
bias_correction = 1 if group['bias_correction'] else 0
multi_tensor_applier(self.gpu_adam_op, self._dummy_overflow_buf, [g_l, p_l, m_l, v_l], group['lr'],
group['betas'][0], group['betas'][1], group['eps'], group_step, adamw_mode,
bias_correction, group['weight_decay'])
bias_correction, group['weight_decay'], div_scale)
self._post_step()
return loss
import math
from enum import Enum
from typing import Any, Dict, Set, Tuple
import torch
import torch.distributed as dist
from enum import Enum
from torch.optim import Optimizer
from torch.nn import Parameter
from colossalai.nn.parallel.data_parallel import ZeroDDP
from typing import Dict, Tuple, Set
from torch.optim import Optimizer
from colossalai.amp.naive_amp.grad_scaler import DynamicGradScaler
from colossalai.logging import get_dist_logger
from colossalai.nn.optimizer import ColossalaiOptimizer
from colossalai.utils import get_current_device, disposable
from colossalai.gemini.chunk import Chunk, ChunkManager
from colossalai.logging import get_dist_logger
from colossalai.nn.optimizer import ColossalaiOptimizer, CPUAdam, FusedAdam, HybridAdam
from colossalai.nn.parallel.data_parallel import ZeroDDP
from colossalai.utils import disposable, get_current_device
_AVAIL_OPTIM_LIST = {FusedAdam, CPUAdam, HybridAdam}
class OptimState(Enum):
......@@ -53,9 +58,13 @@ class ZeroOptimizer(ColossalaiOptimizer):
backoff_factor: float = 0.5,
growth_interval: int = 1000,
hysteresis: int = 2,
max_scale: float = 2**32):
max_scale: float = 2**32,
clipping_norm: float = 0.0,
norm_type: float = 2.0,
**defaults: Any):
super().__init__(optim)
assert isinstance(module, ZeroDDP)
assert type(optim) in _AVAIL_OPTIM_LIST, "you should use the optimizer in the available list"
self.module = module
self.gemini_manager = module.gemini_manager
self.chunk_manager: ChunkManager = self.gemini_manager.chunk_manager
......@@ -63,11 +72,17 @@ class ZeroOptimizer(ColossalaiOptimizer):
self.param_to_range: Dict[Parameter, Tuple[int, int]] = dict()
self.param_to_chunk32: Dict[Parameter, Chunk] = dict()
self.chunk16_set: Set[Chunk] = set()
self.clipping_flag = clipping_norm > 0.0
self.max_norm = clipping_norm
if self.clipping_flag:
assert norm_type == 2.0, "ZeroOptimizer only supports L2 norm now"
params_list = [p for p in module.parameters() if not getattr(p, '_ddp_to_ignore', False)]
for p, fp32_p in zip(params_list, module.fp32_params):
chunk_16 = self.chunk_manager.get_chunk(p)
if chunk_16 not in self.chunk16_set:
chunk_16.l2_norm_flag = self.clipping_flag
self.chunk16_set.add(chunk_16)
self.__init__optimizer()
......@@ -125,13 +140,49 @@ class ZeroOptimizer(ColossalaiOptimizer):
return self._found_overflow.item() > 0
def _unscale_grads(self):
assert self.optim_state == OptimState.SCALED
for group in self.optim.param_groups:
for p in group['params']:
if p.grad is not None:
p.grad.data.div_(self.loss_scale)
self.optim_state = OptimState.UNSCALED
def _calc_global_norm(self) -> float:
norm_sqr: float = 0.0
group_to_norm = dict()
for c16 in self.chunk16_set:
assert c16.l2_norm is not None
if c16.is_gathered:
norm_sqr += c16.l2_norm
else:
# this chunk is sharded, use communication to collect total norm
if c16.torch_pg not in group_to_norm:
group_to_norm[c16.torch_pg] = 0.0
group_to_norm[c16.torch_pg] += c16.l2_norm
c16.l2_norm = None # clear l2 norm
comm_buffer = torch.zeros(1, dtype=torch.float, device=get_current_device())
for group, part_norm in group_to_norm.items():
comm_buffer.fill_(part_norm)
dist.all_reduce(comm_buffer, group=group)
norm_sqr += comm_buffer.item()
global_norm = math.sqrt(norm_sqr)
return global_norm
def _get_combined_scale(self):
loss_scale = 1
if self.optim_state == OptimState.SCALED:
loss_scale = self.loss_scale
self.optim_state = OptimState.UNSCALED
combined_scale = loss_scale
if self.clipping_flag:
total_norm = self._calc_global_norm()
clip = ((total_norm / loss_scale) + 1e-6) / self.max_norm
if clip > 1:
combined_scale = clip * loss_scale
if combined_scale == 1:
return -1
else:
return combined_scale
@property
def loss_scale(self):
......@@ -144,17 +195,22 @@ class ZeroOptimizer(ColossalaiOptimizer):
def step(self, *args, **kwargs):
self._maybe_move_fp32_params()
self._set_grad_ptr()
# unscale grads if scaled
if self.optim_state == OptimState.SCALED:
self._unscale_grads()
found_inf = self._check_overflow()
self.grad_scaler.update(found_inf)
if found_inf:
self.optim_state = OptimState.UNSCALED # no need to unscale grad
self.grad_scaler.update(found_inf) # update gradient scaler
self._logger.info(f'Found overflow. Skip step')
self.zero_grad()
self.zero_grad() # reset all gradients
self._update_fp16_params()
return
ret = self.optim.step(*args, **kwargs)
# get combined scale. combined scale = loss scale * clipping norm
# so that gradient = gradient / combined scale
combined_scale = self._get_combined_scale()
self.grad_scaler.update(found_inf)
ret = self.optim.step(div_scale=combined_scale, *args, **kwargs)
self._register_states()
self.zero_grad()
self._update_fp16_params()
......@@ -219,6 +275,8 @@ class ZeroOptimizer(ColossalaiOptimizer):
def get_range_pair(local_chunk: Chunk, local_param: Parameter):
param_info = local_chunk.tensors_info[local_param]
if local_chunk.keep_gathered:
return param_info.offset, param_info.end
begin = max(0, param_info.offset - local_chunk.shard_begin)
end = min(local_chunk.shard_size, param_info.end - local_chunk.shard_begin)
return begin, end
......
from .data_parallel import ColoDDP, ZeroDDP
from .gemini_parallel import GeminiDDP
__all__ = ['ColoDDP', 'ZeroDDP']
__all__ = ['ColoDDP', 'ZeroDDP', 'GeminiDDP']
import torch
import itertools
import torch.distributed as dist
from collections import OrderedDict
from functools import partial
from colossalai.zero.utils.zero_hook_v2 import ZeROHookV2
from colossalai.tensor.param_op_hook import ParamOpHookManager
from colossalai.gemini.gemini_mgr import GeminiManager
from typing import Dict, Iterable, List, Optional, Set
import torch
import torch.distributed as dist
from colossalai.gemini.chunk import Chunk, ChunkManager, TensorState
from colossalai.gemini.gemini_mgr import GeminiManager
from colossalai.gemini.memory_tracer import OrderedParamGenerator
from colossalai.logging import get_dist_logger
from collections import OrderedDict
from colossalai.tensor.colo_parameter import ColoParameter, ColoTensor, ColoTensorSpec
from colossalai.nn.parallel.utils import get_temp_total_chunk_on_cuda
from colossalai.tensor import ProcessGroup as ColoProcessGroup
from .reducer import Reducer
from colossalai.tensor.colo_parameter import ColoParameter, ColoTensor, ColoTensorSpec
from colossalai.tensor.param_op_hook import ColoParamOpHookManager
from colossalai.utils import get_current_device
from colossalai.zero.utils.gemini_hook import GeminiZeROHook
from colossalai.gemini.chunk import TensorState, Chunk, ChunkManager
from colossalai.nn.parallel.utils import get_temp_total_chunk_on_cuda
from .reducer import Reducer
from .utils import get_static_torch_model
try:
from torch.nn.modules.module import _EXTRA_STATE_KEY_SUFFIX, _IncompatibleKeys
......@@ -185,25 +190,16 @@ class ColoDDP(torch.nn.Module):
class ZeroDDP(ColoDDP):
"""ZeRO-DP for ColoTensor. Nested ZeroDDP is not supported now.
We can configure chunk and gemini via ChunkManager and GeminiManager respectively.
"""ZeRO DDP for ColoTensor.
Warning: Nested ZeroDDP is not supported now.
It is designed to be used with ChunkManager and GeminiManager.
For more details, see the API reference of ``ChunkManager`` and ``GeminiManager``.
Example:
>>> model = torch.nn.Linear(20, 1)
>>> placement_policy = 'cuda'
>>> chunk_size = ChunkManager.search_chunk_size(model, search_range, n_grids) if use_chunk else None
>>> chunk_manager = ChunkManager(chunk_size, enable_distributed_storage=use_zero, init_device=GeminiManager.get_default_device(placement_policy))
>>> gemini_manager = GeminiManager(placement_policy, chunk_manager)
>>> model = ZeroDDP(model, gemini_manager)
>>> logits = model(x)
>>> loss = criterion(logits, labels)
>>> model.backward(loss)
Args:
module (torch.nn.Module): Module to apply ZeRO-DP.
gemini_manager (GeminiManager): Manages the chunk manager and heterogeneous momery space.
For more details, see the API reference of ``GeminiManager``.
pin_memory (bool): Chunks on CPU Memory use pin-memory.
force_outputs_fp32 (bool): If set to True, outputs will be fp32. Otherwise, outputs will be fp16. Defaults to False.
"""
......@@ -216,13 +212,24 @@ class ZeroDDP(ColoDDP):
self.gemini_manager = gemini_manager
self.chunk_manager: ChunkManager = gemini_manager.chunk_manager
self.force_outputs_fp32 = force_outputs_fp32
self.param_op_hook = ZeROHookV2(gemini_manager)
self.param_op_hook = GeminiZeROHook(gemini_manager)
self.fp32_params: List[ColoTensor] = []
self.overflow_counter = 0
self.grads_device: Dict[torch.Tensor, torch.device] = {}
# TODO: get param order and filter unused params
for p in module.parameters():
cpu_offload = self.gemini_manager.policy_name != 'cuda'
if self.gemini_manager._premade_memstats_:
# build chunk in param runtime visited order.
param_order = self.gemini_manager.memstats()._param_runtime_order
else:
# build chunk in param initialized order.
# Note: in this way, it can not get filter unused params during runtime.
param_order = OrderedParamGenerator()
for p in module.parameters():
param_order.append(p)
for p in param_order.generate():
assert isinstance(p, ColoParameter)
if getattr(p, '_ddp_to_ignore', False):
......@@ -232,28 +239,40 @@ class ZeroDDP(ColoDDP):
fp32_data = p.data.float()
fp32_p = ColoTensor(fp32_data, spec=ColoTensorSpec(p.process_group))
p.data = p.data.half()
dp_world_size = p.process_group.dp_world_size()
self.chunk_manager.append_tensor(p, 'fp16_param', dp_world_size, pin_memory)
self.chunk_manager.append_tensor(fp32_p, 'fp32_param', dp_world_size, pin_memory)
self.chunk_manager.register_tensor(tensor=p,
group_type='fp16_param',
config_key=dp_world_size,
cpu_offload=cpu_offload,
pin_memory=pin_memory)
self.chunk_manager.register_tensor(tensor=fp32_p,
group_type='fp32_param',
config_key=dp_world_size,
cpu_offload=cpu_offload,
pin_memory=pin_memory)
self.fp32_params.append(fp32_p)
self.grads_device[p] = self.gemini_manager.default_device
self.chunk_manager.close_all_groups()
self._cast_buffers()
params_list = [p for p in module.parameters() if not getattr(p, '_ddp_to_ignore', False)]
params_list = [p for p in param_order.generate() if not getattr(p, '_ddp_to_ignore', False)]
for p, fp32_p in zip(params_list, self.fp32_params):
chunk_16 = self.chunk_manager.get_chunk(p)
chunk_32 = self.chunk_manager.get_chunk(fp32_p)
chunk_32.init_pair(chunk_16)
# keep gathered chunks are in CUDA
if chunk_16.keep_gathered:
self.grads_device[p] = get_current_device()
self._logger = get_dist_logger()
def forward(self, *args, **kwargs):
args, kwargs = _cast_float(args, torch.half), _cast_float(kwargs, torch.half)
self.module.zero_grad(set_to_none=True)
self.gemini_manager.pre_iter()
with ParamOpHookManager.use_hooks(self.param_op_hook):
self.gemini_manager.pre_iter(*args)
with ColoParamOpHookManager.use_hooks(self.param_op_hook):
outputs = self.module(*args, **kwargs)
if self.force_outputs_fp32:
return _cast_float(outputs, torch.float)
......@@ -266,7 +285,9 @@ class ZeroDDP(ColoDDP):
p.grad = None
def _post_backward(self):
assert self.chunk_manager.accessed_mem == 0
if self.chunk_manager.accessed_mem != 0:
raise RuntimeError("ZERO DDP error: the synchronization of gradients doesn't exit properly.",
"The most possible reason is that the model is not compatible with ZeroDDP.")
self._setup_grads_ptr()
self._logger.debug(
f'comp cuda demand time: {self.gemini_manager._comp_cuda_demand_time}, layout time: {self.gemini_manager._layout_time}, evict time: {self.gemini_manager._evict_time}, CPU->CUDA vol: {self.gemini_manager._h2d_volume}B, CUDA->CPU vol: {self.gemini_manager._d2h_volume}'
......@@ -274,12 +295,12 @@ class ZeroDDP(ColoDDP):
self.gemini_manager.post_iter()
def backward(self, loss: torch.Tensor):
with self.param_op_hook.switch_to_backward(), ParamOpHookManager.use_hooks(self.param_op_hook):
with self.param_op_hook.switch_to_backward(), ColoParamOpHookManager.use_hooks(self.param_op_hook):
loss.backward()
self._post_backward()
def backward_by_grad(self, tensor, grad):
with self.param_op_hook.switch_to_backward(), ParamOpHookManager.use_hooks(self.param_op_hook):
with self.param_op_hook.switch_to_backward(), ColoParamOpHookManager.use_hooks(self.param_op_hook):
torch.autograd.backward(tensor, grad)
self._post_backward()
......@@ -287,16 +308,21 @@ class ZeroDDP(ColoDDP):
empty_grad = torch.empty_like(grad)
free_storage(empty_grad)
with torch._C.DisableTorchFunction():
self.chunk_manager.trans_tensor_state(p, TensorState.READY_FOR_REDUCE)
chunk = self.chunk_manager.get_chunk(p)
assert chunk.tensors_info[p].state == TensorState.HOLD_AFTER_BWD
self.chunk_manager.trans_tensor_state(p, TensorState.READY_FOR_REDUCE)
chunk.copy_tensor_to_chunk_slice(p, grad)
reduced = self.chunk_manager.reduce_chunk(chunk)
if reduced:
if chunk.is_gathered:
chunk.chunk_total.div_(chunk.pg_size)
chunk.cuda_global_chunk.div_(chunk.pg_size)
else:
chunk.cuda_shard.div_(chunk.pg_size)
# check overflow elements
self.overflow_counter += chunk.has_inf_or_nan
# record l2 norm for gradient clipping
if chunk.l2_norm_flag:
chunk.set_l2_norm()
self.chunk_manager.move_chunk(chunk, self.grads_device[p], force_copy=True)
return empty_grad
......@@ -307,12 +333,10 @@ class ZeroDDP(ColoDDP):
for tensor in chunk.get_tensors():
self.grads_device[tensor] = device
def state_dict(self, destination=None, prefix='', keep_vars=False, only_rank_0: bool = True):
r"""Returns a dictionary containing a whole state of the module.
Both parameters and persistent buffers (e.g. running averages) are
included. Keys are corresponding parameter and buffer names.
Parameters and buffers set to ``None`` are not included.
def state_dict(self, destination=None, prefix='', keep_vars=False, only_rank_0: bool = True, strict: bool = True):
"""
Args:
strict (bool): whether to reture the whole model state as the pytorch `Module.state_dict()`
Returns:
dict:
......@@ -322,7 +346,30 @@ class ZeroDDP(ColoDDP):
>>> module.state_dict().keys()
['bias', 'weight']
"""
if strict:
assert keep_vars is False, "`state_dict` with parameter, `keep_vars=True`, is not supported now."
torch_model = get_static_torch_model(zero_ddp_model=self, only_rank_0=only_rank_0)
return torch_model.state_dict(destination=destination, prefix=prefix, keep_vars=keep_vars)
return self._non_strict_state_dict(destination=destination,
prefix=prefix,
keep_vars=keep_vars,
only_rank_0=only_rank_0)
def _non_strict_state_dict(self, destination=None, prefix='', keep_vars=False, only_rank_0: bool = True):
"""Returns a dictionary containing a whole state of the module.
Both parameters and persistent buffers (e.g. running averages) are included.
Keys are corresponding parameter and buffer names.
Parameters and buffers set to ``None`` are not included.
Warning: The non strict state dict would ignore the parameters if the tensors of the parameters
are shared with other parameters which have been included in the dictionary.
When you need to load the state dict, you should set the argument `strict` to False.
Returns:
dict:
a dictionary containing a whole state of the module
"""
if destination is None:
destination = OrderedDict()
......@@ -336,24 +383,20 @@ class ZeroDDP(ColoDDP):
destination = hook_result
return destination
def _save_to_state_dict(self, destination, prefix, keep_vars, only_rank_0=True):
r"""Saves module state to `destination` dictionary, containing a state
of the module, but not its descendants. This is called on every
submodule in :meth:`~torch.nn.Module.state_dict`.
In rare cases, subclasses can achieve class-specific behavior by
overriding this method with custom logic.
def _get_param_to_save_data(self, param_list: List[torch.nn.Parameter], only_rank_0: bool) -> Dict:
"""
get param content from chunks.
Args:
destination (dict): a dict where state will be stored
prefix (str): the prefix for parameters and buffers used in this
module
"""
assert keep_vars is False, "`state_dict` with parameter, `keep_vars=True`, is not supported now."
param_list (_type_): a list of torch.nn.Parameters
only_rank_0 (_type_): _description_
Returns:
Dict: a dict whose key is param name and value is param with correct payload
"""
# save parameters
param_to_save_data = dict()
chunk_list = self.chunk_manager.get_chunks(self.fp32_params)
chunk_list = self.chunk_manager.get_chunks(param_list)
for chunk in chunk_list:
temp_chunk = get_temp_total_chunk_on_cuda(chunk)
......@@ -367,7 +410,25 @@ class ZeroDDP(ColoDDP):
param_to_save_data[tensor] = record_tensor
del temp_chunk
return param_to_save_data
def _save_to_state_dict(self, destination, prefix, keep_vars, only_rank_0=True):
r"""Saves module state to `destination` dictionary, containing a state
of the module, but not its descendants. This is called on every
submodule in :meth:`~torch.nn.Module.state_dict`.
In rare cases, subclasses can achieve class-specific behavior by
overriding this method with custom logic.
Args:
destination (dict): a dict where state will be stored
prefix (str): the prefix for parameters and buffers used in this
module
"""
assert keep_vars is False, "`state_dict` with parameter, `keep_vars=True`, is not supported now."
param_to_save_data = self._get_param_to_save_data(self.fp32_params, only_rank_0)
# TODO: (HELSON) deal with ddp ignored parameters
for (name, p), fp32_p in zip(self.named_parameters(), self.fp32_params):
if p is not None:
assert fp32_p in param_to_save_data, "Parameter '{}' is neglected in the chunk list".format(name)
......@@ -519,7 +580,7 @@ class ZeroDDP(ColoDDP):
load(parameter_name, tensor, partial(load_fp32_parameter, parameter_slice))
if chunk.is_gathered:
chunk.chunk_total.copy_(temp_chunk)
chunk.cuda_global_chunk.copy_(temp_chunk)
elif chunk.cuda_shard is not None:
chunk.cuda_shard.copy_(temp_chunk[chunk.shard_begin:chunk.shard_end])
else:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment