Commit e532679c authored by oahzxl's avatar oahzxl
Browse files

Merge branch 'main' of https://github.com/oahzxl/ColossalAI into chunk

parents c1492e50 7d5640b9
from colossalai.tensor import ColoTensor
from colossalai.context.singleton_meta import SingletonMeta
class GraphGlobalEnv(metaclass=SingletonMeta):
    """Process-wide singleton holding the state of an in-progress graph trace.

    Attributes:
        graph_building: True while a ``GraphContext`` is active.
        graph_node_list: op nodes recorded during the current trace.
        node_id: the last id handed out; -1 means no node created yet.
    """

    def __init__(self) -> None:
        self.graph_building = False
        self.graph_node_list = []
        self.node_id = -1

    def get_node_id(self):
        """Return the next monotonically increasing node id."""
        self.node_id = self.node_id + 1
        return self.node_id

    def add_graph_node(self, node):
        """Append *node* to the list of nodes recorded for this trace."""
        self.graph_node_list.append(node)
class GraphContext():
    """
    Building the computing graph under the context

    >>> with GraphContext():
    >>>     output = model(colo_input_tensor)
    """
    # Filled with the traced nodes when the context exits.
    graph_nodes = []

    def __enter__(self):
        env = GraphGlobalEnv()
        env.graph_building = True
        env.graph_node_list = []

    def __exit__(self, *exc_info):
        env = GraphGlobalEnv()
        env.graph_building = False
        env.node_id = -1
        # Keep a handle on the nodes collected during this trace.
        self.graph_nodes = env.graph_node_list
class GraphNode(object):
    """A node of the compute graph, tracking predecessor/successor links."""

    def __init__(self) -> None:
        self.prev_nodes = []
        self.post_nodes = []
        self.id = GraphGlobalEnv().get_node_id()

    def add_prev_node(self, node):
        """Record *node* as a predecessor; no-op unless a trace is active."""
        if GraphGlobalEnv().graph_building:
            self.prev_nodes.append(node)

    def add_post_node(self, node):
        """Record *node* as a successor; no-op unless a trace is active."""
        if GraphGlobalEnv().graph_building:
            self.post_nodes.append(node)

    def post_node_empty(self) -> bool:
        """Return True when this node has no successors."""
        return not self.post_nodes
class GraphOpNode(GraphNode):
    """Graph node representing one operator (layer) and its parameters."""

    def __init__(self, op_type, param_list) -> None:
        super().__init__()
        self._op_type = op_type
        self._param_list = param_list
        # Every op node registers itself with the global trace.
        GraphGlobalEnv().add_graph_node(self)

    def add_prev_tensor(self, colo_tensor: ColoTensor):
        r"""
        Link the current graph op node to previous graph op.
        Op1 <- Activation (colo_tensor) Op2
        Op1 <- Op2
        """
        if not GraphGlobalEnv().graph_building:
            return
        assert isinstance(colo_tensor, ColoTensor)
        if colo_tensor._graph_node is None:
            # Activation produced outside any traced op: give it a bare node.
            colo_tensor._graph_node = GraphNode()
        for producer in colo_tensor._graph_node.prev_nodes:
            # Wire producer -> self in both directions.
            self.add_prev_node(producer)
            producer.add_post_node(self)

    def add_post_tensor(self, colo_tensor: ColoTensor):
        """
        Op <- Activation (colo_tensor)
        """
        if not GraphGlobalEnv().graph_building:
            return
        assert isinstance(colo_tensor, ColoTensor), f'type {type(colo_tensor)}'
        if colo_tensor._graph_node is None:
            colo_tensor._graph_node = GraphNode()
        colo_tensor._graph_node.add_prev_node(self)

    def print(self):
        print(
            f'GraphOpNode {self._op_type} {self.id}, post nodes {[node.id for node in self.post_nodes]}, prev node number {[node.id for node in self.prev_nodes]}'
        )
import functools
import torch
from colossalai.tensor import ColoTensor
from typing import Callable, List
from colossalai.nn._ops._utils import convert_to_colo_tensor
def register_colo_graph(input_pos: List[int], param_pos: List[int]) -> Callable:
    """register_colo_graph
    Register an Op (Layer) to the ColoGraph.
    Records the input args of type ColoTensor into the Graph.
    Args:
        input_pos (List[int]): positional indices of the op's input tensors.
        param_pos (List[int]): positional indices of the op's parameter tensors.
    Returns:
        Callable: decorator that wraps the op implementation.
    """

    def register_colo_graph_decorator(func):
        from colossalai.nn.graph import GraphGlobalEnv, GraphOpNode

        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            # TODO(jiaruifang) find the pg
            input_list = [
                convert_to_colo_tensor(arg) for idx, arg in enumerate(args)
                if idx in input_pos and isinstance(arg, torch.Tensor)
            ]
            param_list = [
                convert_to_colo_tensor(arg) for idx, arg in enumerate(args)
                if idx in param_pos and isinstance(arg, torch.Tensor)
            ]

            # building the computing graph, inputs -> op
            if GraphGlobalEnv().graph_building:
                cur_op_node = GraphOpNode('linear', param_list)
                # TODO supports a list of ColoTensor as args
                if len(input_list) > 0:
                    cur_op_node.add_prev_tensor(input_list[0])

            outputs = func(*args, **kwargs)

            # building the computing graph, op -> output
            if GraphGlobalEnv().graph_building:
                # TODO supports a list of ColoTensor as args
                if isinstance(outputs[0], ColoTensor):
                    cur_op_node.add_post_tensor(outputs[0])
            return outputs

        return wrapper

    return register_colo_graph_decorator
import math import inspect
import inspect import math
from typing import Callable from typing import Callable
from colossalai.utils import get_current_device from torch import dtype, nn
from torch import dtype, nn
from colossalai.utils import get_current_device
from ... import init as init
from ..parallel_1d import * from ... import init as init
from ..parallel_2d import * from ..parallel_1d import *
from ..parallel_2p5d import * from ..parallel_2d import *
from ..parallel_3d import * from ..parallel_2p5d import *
from ..utils import get_tensor_parallel_mode from ..parallel_3d import *
from ..vanilla import * from ..utils import get_tensor_parallel_mode
from ._utils import ColossalaiModule from ..vanilla import *
from ._utils import ColossalaiModule
_parallel_linear = {'1d': Linear1D, '2d': Linear2D, '2.5d': Linear2p5D, '3d': Linear3D}
_parallel_linear = {None: VanillaLinear, '1d': Linear1D, '2d': Linear2D, '2.5d': Linear2p5D, '3d': Linear3D}
_parallel_classifier = {
None: VanillaClassifier, _parallel_classifier = {
'1d': Classifier1D, None: VanillaClassifier,
'2d': Classifier2D, '1d': Classifier1D,
'2.5d': Classifier2p5D, '2d': Classifier2D,
'3d': Classifier3D '2.5d': Classifier2p5D,
} '3d': Classifier3D
}
_vocab_parallel_classifier = {
'1d': VocabParallelClassifier1D, _vocab_parallel_classifier = {
'2d': VocabParallelClassifier2D, '1d': VocabParallelClassifier1D,
'2.5d': VocabParallelClassifier2p5D, '2d': VocabParallelClassifier2D,
'3d': VocabParallelClassifier3D '2.5d': VocabParallelClassifier2p5D,
} '3d': VocabParallelClassifier3D
}
class Linear(ColossalaiModule):
"""Linear layer of colossalai. class Linear(ColossalaiModule):
"""Linear layer of colossalai.
Args:
in_features (int): size of each input sample. Args:
out_features (int): size of each output sample. in_features (int): size of each input sample.
bias (bool, optional): If set to ``False``, the layer will not learn an additive bias, defaults to ``True``. out_features (int): size of each output sample.
dtype (:class:`torch.dtype`, optional): The dtype of parameters, defaults to None. bias (bool, optional): If set to ``False``, the layer will not learn an additive bias, defaults to ``True``.
weight_initializer (:class:`typing.Callable`, optional): dtype (:class:`torch.dtype`, optional): The dtype of parameters, defaults to None.
The initializer of weight, defaults to kaiming uniform initializer. weight_initializer (:class:`typing.Callable`, optional):
bias_initializer (:class:`typing.Callable`, optional): The initializer of weight, defaults to kaiming uniform initializer.
The initializer of bias, defaults to xavier uniform initializer. bias_initializer (:class:`typing.Callable`, optional):
The initializer of bias, defaults to xavier uniform initializer.
Note: ``kwargs`` would contain different parameters when you use different parallelisms.
Note: ``kwargs`` would contain different parameters when you use different parallelisms.
The ``kwargs`` should contain parameters below:
:: The ``kwargs`` should contain parameters below:
::
Linear1D:
gather_output: bool (optional, default to be false) Linear1D:
skip_bias_add: bool (optional, default to be false) gather_output: bool (optional, default to be false)
Linear2D: skip_bias_add: bool (optional, default to be false)
skip_bias_add: bool (optional, default to be false) Linear2D:
Linear2p5D: skip_bias_add: bool (optional, default to be false)
skip_bias_add: bool (optional, default to be false) Linear2p5D:
Linear3D: skip_bias_add: bool (optional, default to be false)
None Linear3D:
None
More details about ``initializer`` please refer to
`init <https://github.com/hpcaitech/ColossalAI/blob/main/colossalai/nn/init.py>`_. More details about ``initializer`` please refer to
""" `init <https://github.com/hpcaitech/ColossalAI/blob/main/colossalai/nn/init.py>`_.
"""
def __init__(self,
in_features: int, def __init__(self,
out_features: int, in_features: int,
bias: bool = True, out_features: int,
dtype: dtype = None, bias: bool = True,
weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)), dtype: dtype = None,
bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1), weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)),
**kwargs) -> None: bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1),
tensor_parallel = get_tensor_parallel_mode() **kwargs) -> None:
if tensor_parallel is None: tensor_parallel = get_tensor_parallel_mode()
layer = nn.Linear(in_features, out_features, bias=bias).to(dtype).to(get_current_device()) linear_cls = _parallel_linear[tensor_parallel]
weight_initializer(layer.weight, fan_in=in_features, fan_out=out_features) gather_output = kwargs.pop('gather_output', None)
if layer.bias is not None: if 'gather_output' in inspect.signature(linear_cls.__init__).parameters.keys(): # gather_out arg is available
bias_initializer(layer.bias, fan_in=in_features) kwargs['gather_output'] = gather_output
else: layer = linear_cls(
linear_cls = _parallel_linear[tensor_parallel] in_features,
gather_output = kwargs.pop('gather_output', None) out_features,
if 'gather_output' in inspect.signature( bias=bias,
linear_cls.__init__).parameters.keys(): # gather_out arg is available dtype=dtype,
kwargs['gather_output'] = gather_output weight_initializer=weight_initializer,
layer = linear_cls( bias_initializer=bias_initializer,
in_features, **kwargs,
out_features, )
bias=bias, super().__init__(layer)
dtype=dtype,
weight_initializer=weight_initializer,
bias_initializer=bias_initializer, class Classifier(ColossalaiModule):
**kwargs, """Classifier layer of colossalai.
)
super().__init__(layer) Args:
in_features (int): size of each input sample.
num_classes (int): number of classes.
class Classifier(ColossalaiModule): weight (:class:`torch.nn.Parameter`, optional): weight of the classifier, defaults to None.
"""Classifier layer of colossalai. bias (bool, optional): If set to ``False``, the layer will not learn an additive bias, defaults to ``True``.
dtype (:class:`torch.dtype`, optional): The dtype of parameters, defaults to None.
Args: weight_initializer (:class:`typing.Callable`, optional):
in_features (int): size of each input sample. The initializer of weight, defaults to kaiming uniform initializer.
num_classes (int): number of classes. bias_initializer (:class:`typing.Callable`, optional):
weight (:class:`torch.nn.Parameter`, optional): weight of the classifier, defaults to None. The initializer of bias, defaults to xavier uniform initializer.
bias (bool, optional): If set to ``False``, the layer will not learn an additive bias, defaults to ``True``.
dtype (:class:`torch.dtype`, optional): The dtype of parameters, defaults to None. More details about ``initializer`` please refer to
weight_initializer (:class:`typing.Callable`, optional): `init <https://github.com/hpcaitech/ColossalAI/blob/main/colossalai/nn/init.py>`_.
The initializer of weight, defaults to kaiming uniform initializer. """
bias_initializer (:class:`typing.Callable`, optional):
The initializer of bias, defaults to xavier uniform initializer. def __init__(self,
in_features: int,
More details about ``initializer`` please refer to num_classes: int,
`init <https://github.com/hpcaitech/ColossalAI/blob/main/colossalai/nn/init.py>`_. weight: nn.Parameter = None,
""" bias: bool = True,
dtype: dtype = None,
def __init__(self, weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)),
in_features: int, bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1),
num_classes: int, vocab_parallel_limit: int = 2048) -> None:
weight: nn.Parameter = None, tensor_parallel = get_tensor_parallel_mode()
bias: bool = True, if num_classes <= vocab_parallel_limit or tensor_parallel is None:
dtype: dtype = None, layer = _parallel_classifier[tensor_parallel](
weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)), in_features,
bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1), num_classes,
vocab_parallel_limit: int = 2048) -> None: weight=weight,
tensor_parallel = get_tensor_parallel_mode() bias=bias,
if num_classes <= vocab_parallel_limit or tensor_parallel is None: dtype=dtype,
layer = _parallel_classifier[tensor_parallel]( weight_initializer=weight_initializer,
in_features, bias_initializer=bias_initializer,
num_classes, )
weight=weight, else:
bias=bias, layer = _vocab_parallel_classifier[tensor_parallel](
dtype=dtype, in_features,
weight_initializer=weight_initializer, num_classes,
bias_initializer=bias_initializer, weight=weight,
) bias=bias,
else: dtype=dtype,
layer = _vocab_parallel_classifier[tensor_parallel]( weight_initializer=weight_initializer,
in_features, bias_initializer=bias_initializer,
num_classes, )
weight=weight, super().__init__(layer)
bias=bias,
dtype=dtype,
weight_initializer=weight_initializer,
bias_initializer=bias_initializer,
)
super().__init__(layer)
import torch from typing import Any, Optional, Tuple
import torch.distributed as dist
from torch import Tensor import torch
from typing import Any, Tuple, Optional import torch.distributed as dist
from torch.distributed import ProcessGroup from torch import Tensor
from torch.distributed import ProcessGroup
COL_MOE_KERNEL_FLAG = False
try: COL_MOE_KERNEL_FLAG = False
import colossal_moe_cuda
try:
COL_MOE_KERNEL_FLAG = True from colossalai._C import moe
except ImportError: except:
print("If you want to activate cuda mode for MoE, please install with cuda_ext!") moe = None
class AllGather(torch.autograd.Function): def build_moe_if_not_prebuilt():
# load moe kernel during runtime if not pre-built
@staticmethod global moe
def forward(ctx: Any, inputs: Tensor, group: Optional[ProcessGroup] = None) -> Tensor: if moe is None:
if ctx is not None: from colossalai.kernel.op_builder import MOEBuilder
ctx.comm_grp = group moe = MOEBuilder().load()
comm_size = dist.get_world_size(group)
if comm_size == 1: class AllGather(torch.autograd.Function):
return inputs.unsqueeze(0)
@staticmethod
buffer_shape = (comm_size,) + inputs.shape def forward(ctx: Any, inputs: Tensor, group: Optional[ProcessGroup] = None) -> Tensor:
outputs = torch.empty(buffer_shape, dtype=inputs.dtype, device=inputs.device)
buffer_list = list(torch.chunk(outputs, comm_size, dim=0)) global moe
dist.all_gather(buffer_list, inputs, group=group)
return outputs if moe is None:
from colossalai.kernel.op_builder import MOEBuilder
@staticmethod moe = MOEBuilder().load()
def backward(ctx: Any, grad_outputs: Tensor) -> Tuple[Tensor, None]:
return ReduceScatter.forward(None, grad_outputs, ctx.comm_grp), None if ctx is not None:
ctx.comm_grp = group
class ReduceScatter(torch.autograd.Function): comm_size = dist.get_world_size(group)
if comm_size == 1:
@staticmethod return inputs.unsqueeze(0)
def forward(ctx: Any, inputs: Tensor, group: Optional[ProcessGroup] = None) -> Tensor:
if ctx is not None: buffer_shape = (comm_size,) + inputs.shape
ctx.comm_grp = group outputs = torch.empty(buffer_shape, dtype=inputs.dtype, device=inputs.device)
buffer_list = list(torch.chunk(outputs, comm_size, dim=0))
comm_size = dist.get_world_size(group) dist.all_gather(buffer_list, inputs, group=group)
if comm_size == 1: return outputs
return inputs.squeeze(0)
@staticmethod
if not inputs.is_contiguous(): def backward(ctx: Any, grad_outputs: Tensor) -> Tuple[Tensor, None]:
inputs = inputs.contiguous() return ReduceScatter.forward(None, grad_outputs, ctx.comm_grp), None
output_shape = inputs.shape[1:]
outputs = torch.empty(output_shape, dtype=inputs.dtype, device=inputs.device) class ReduceScatter(torch.autograd.Function):
buffer_list = list(torch.chunk(inputs, comm_size, dim=0))
dist.reduce_scatter(outputs, buffer_list, group=group) @staticmethod
return outputs def forward(ctx: Any, inputs: Tensor, group: Optional[ProcessGroup] = None) -> Tensor:
if ctx is not None:
@staticmethod ctx.comm_grp = group
def backward(ctx: Any, grad_outputs: Tensor) -> Tuple[Tensor, None]:
return AllGather.forward(None, grad_outputs, ctx.comm_grp), None comm_size = dist.get_world_size(group)
if comm_size == 1:
return inputs.squeeze(0)
class AllToAll(torch.autograd.Function):
"""Dispatches input tensor [e, c, h] to all experts by all_to_all_single if not inputs.is_contiguous():
operation in torch.distributed. inputs = inputs.contiguous()
"""
output_shape = inputs.shape[1:]
@staticmethod outputs = torch.empty(output_shape, dtype=inputs.dtype, device=inputs.device)
def forward(ctx: Any, inputs: Tensor, group: Optional[ProcessGroup] = None) -> Tensor: buffer_list = list(torch.chunk(inputs, comm_size, dim=0))
if ctx is not None: dist.reduce_scatter(outputs, buffer_list, group=group)
ctx.comm_grp = group return outputs
if not inputs.is_contiguous():
inputs = inputs.contiguous() @staticmethod
if dist.get_world_size(group) == 1: def backward(ctx: Any, grad_outputs: Tensor) -> Tuple[Tensor, None]:
return inputs return AllGather.forward(None, grad_outputs, ctx.comm_grp), None
output = torch.empty_like(inputs)
dist.all_to_all_single(output, inputs, group=group)
return output class AllToAll(torch.autograd.Function):
"""Dispatches input tensor [e, c, h] to all experts by all_to_all_single
@staticmethod operation in torch.distributed.
def backward(ctx: Any, *grad_outputs: Tensor) -> Tuple[Tensor, None]: """
return AllToAll.forward(None, *grad_outputs, ctx.comm_grp), None
@staticmethod
def forward(ctx: Any, inputs: Tensor, group: Optional[ProcessGroup] = None) -> Tensor:
class MoeDispatch(torch.autograd.Function): if ctx is not None:
ctx.comm_grp = group
@staticmethod if not inputs.is_contiguous():
def forward(ctx, tokens, mask, dest_idx, ec): inputs = inputs.contiguous()
s = tokens.size(0) if dist.get_world_size(group) == 1:
h = tokens.size(1) return inputs
output = torch.empty_like(inputs)
expert_input = colossal_moe_cuda.dispatch_forward(s, ec, h, tokens, mask, dest_idx) dist.all_to_all_single(output, inputs, group=group)
return output
ctx.save_for_backward(mask, dest_idx)
ctx.s = s @staticmethod
ctx.h = h def backward(ctx: Any, *grad_outputs: Tensor) -> Tuple[Tensor, None]:
ctx.ec = ec return AllToAll.forward(None, *grad_outputs, ctx.comm_grp), None
return expert_input
class MoeDispatch(torch.autograd.Function):
@staticmethod
def backward(ctx, output_grad): @staticmethod
mask, dest_idx = ctx.saved_tensors def forward(ctx, tokens, mask, dest_idx, ec):
d_tokens = colossal_moe_cuda.dispatch_backward(ctx.s, ctx.ec, ctx.h, output_grad, mask, dest_idx) s = tokens.size(0)
return d_tokens, None, None, None h = tokens.size(1)
# load moe kernel during runtime if not pre-built
class MoeCombine(torch.autograd.Function): build_moe_if_not_prebuilt()
@staticmethod expert_input = moe.dispatch_forward(s, ec, h, tokens, mask, dest_idx)
def forward(ctx, expert_tokens, logits, mask, dest_idx, ec):
assert logits.dtype == torch.float32 ctx.save_for_backward(mask, dest_idx)
ctx.s = s
s = logits.size(0) ctx.h = h
e = logits.size(1) ctx.ec = ec
c = ec // e
h = expert_tokens.size(-1) return expert_input
fp16_flag = (expert_tokens.dtype == torch.float16) @staticmethod
cb_input = expert_tokens.to(torch.float32) if fp16_flag else expert_tokens def backward(ctx, output_grad):
ctokens = colossal_moe_cuda.combine_forward(s, e, c, h, cb_input, logits, mask, dest_idx) mask, dest_idx = ctx.saved_tensors
output = ctokens.to(torch.float16) if fp16_flag else ctokens d_tokens = moe.dispatch_backward(ctx.s, ctx.ec, ctx.h, output_grad, mask, dest_idx)
return d_tokens, None, None, None
ctx.save_for_backward(expert_tokens, logits, mask, dest_idx)
ctx.s = s
ctx.e = e class MoeCombine(torch.autograd.Function):
ctx.c = c
ctx.h = h @staticmethod
ctx.fp16_flag = fp16_flag def forward(ctx, expert_tokens, logits, mask, dest_idx, ec):
assert logits.dtype == torch.float32
return output
s = logits.size(0)
@staticmethod e = logits.size(1)
def backward(ctx, tokens_grad): c = ec // e
expert_tokens, logits, mask, dest_idx = ctx.saved_tensors h = expert_tokens.size(-1)
cb_grad = tokens_grad.to(torch.float32) if tokens_grad.dtype is torch.float16 \ # load moe kernel during runtime if not pre-built
else tokens_grad build_moe_if_not_prebuilt()
cb_input = expert_tokens.to(torch.float32) if ctx.fp16_flag else expert_tokens
d_expert, d_logits = colossal_moe_cuda.combine_backward(ctx.s, ctx.e, ctx.c, ctx.h, cb_grad, cb_input, logits, fp16_flag = (expert_tokens.dtype == torch.float16)
mask, dest_idx) cb_input = expert_tokens.to(torch.float32) if fp16_flag else expert_tokens
d_expert = d_expert.to(torch.float16) if ctx.fp16_flag else d_expert ctokens = moe.combine_forward(s, e, c, h, cb_input, logits, mask, dest_idx)
output = ctokens.to(torch.float16) if fp16_flag else ctokens
return d_expert, d_logits, None, None, None
ctx.save_for_backward(expert_tokens, logits, mask, dest_idx)
ctx.s = s
def moe_cumsum(inputs: Tensor): ctx.e = e
dim0 = inputs.size(0) ctx.c = c
flag = (dim0 <= 1024) or (dim0 <= 2048 and dim0 % 2 == 0) or (dim0 % 4 == 0) ctx.h = h
if flag and COL_MOE_KERNEL_FLAG: ctx.fp16_flag = fp16_flag
return colossal_moe_cuda.cumsum_sub_one(inputs)
else: return output
return torch.cumsum(inputs, dim=0) - 1
@staticmethod
def backward(ctx, tokens_grad):
expert_tokens, logits, mask, dest_idx = ctx.saved_tensors
cb_grad = tokens_grad.to(torch.float32) if tokens_grad.dtype is torch.float16 \
else tokens_grad
cb_input = expert_tokens.to(torch.float32) if ctx.fp16_flag else expert_tokens
d_expert, d_logits = moe.combine_backward(ctx.s, ctx.e, ctx.c, ctx.h, cb_grad, cb_input, logits, mask, dest_idx)
d_expert = d_expert.to(torch.float16) if ctx.fp16_flag else d_expert
return d_expert, d_logits, None, None, None
def moe_cumsum(inputs: Tensor):
dim0 = inputs.size(0)
flag = (dim0 <= 1024) or (dim0 <= 2048 and dim0 % 2 == 0) or (dim0 % 4 == 0)
if flag and COL_MOE_KERNEL_FLAG:
# load moe kernel during runtime if not pre-built
build_moe_if_not_prebuilt()
return moe.cumsum_sub_one(inputs)
else:
return torch.cumsum(inputs, dim=0) - 1
import torch import torch
import torch.distributed as dist
from colossalai.core import global_context as gpc
try: try:
import fused_mix_prec_layer_norm_cuda import fused_mix_prec_layer_norm_cuda
...@@ -43,3 +45,52 @@ class FusedLayerNormAffineFunction1D(torch.autograd.Function): ...@@ -43,3 +45,52 @@ class FusedLayerNormAffineFunction1D(torch.autograd.Function):
weight_, bias_, ctx.eps) weight_, bias_, ctx.eps)
return grad_input, grad_weight, grad_bias, None, None return grad_input, grad_weight, grad_bias, None, None
class LinearWithAsyncCommunication(torch.autograd.Function):
    """
    Linear layer execution with asynchronous communication in backprop.

    Forward computes ``input_ @ weight.T (+ bias)``; backward optionally
    overlaps the all-reduce of the input gradient with the weight-gradient
    matmul when ``async_grad_allreduce`` is set.
    """

    @staticmethod
    def forward(ctx, input_, weight, bias, parallel_mode, async_grad_allreduce):
        ctx.save_for_backward(input_, weight)
        ctx.use_bias = bias is not None
        ctx.parallel_mode = parallel_mode
        ctx.async_grad_allreduce = async_grad_allreduce

        out = torch.matmul(input_, weight.t())
        if bias is None:
            return out
        return out + bias

    @staticmethod
    def backward(ctx, grad_output):
        saved_input, weight = ctx.saved_tensors
        has_bias = ctx.use_bias

        grad_input = grad_output.matmul(weight)

        # Convert the tensor shapes to 2D for execution compatibility
        # (assumes 3D activations, e.g. (batch, seq, hidden) — flatten the
        # first two dims before the weight-gradient matmul).
        grad_output = grad_output.view(grad_output.shape[0] * grad_output.shape[1], grad_output.shape[2])
        flat_input = saved_input.view(saved_input.shape[0] * saved_input.shape[1], saved_input.shape[2])

        handle = None
        if ctx.async_grad_allreduce:
            # Asynchronous all-reduce
            handle = dist.all_reduce(grad_input, group=gpc.get_group(ctx.parallel_mode), async_op=True)
            # Delay the start of weight gradient computation shortly (3us) to have
            # all-reduce scheduled first and have GPU resources allocated
            _ = torch.empty(1, device=grad_output.device) + 1

        grad_weight = grad_output.t().matmul(flat_input)
        grad_bias = grad_output.sum(dim=0) if has_bias else None

        if handle is not None:
            handle.wait()

        return grad_input, grad_weight, grad_bias, None, None, None
def linear_with_async_comm(input_, weight, bias, parallel_mode, async_grad_allreduce):
    """Functional entry point for :class:`LinearWithAsyncCommunication`."""
    result = LinearWithAsyncCommunication.apply(input_, weight, bias, parallel_mode, async_grad_allreduce)
    return result
...@@ -7,6 +7,9 @@ from typing import Callable, Tuple ...@@ -7,6 +7,9 @@ from typing import Callable, Tuple
import torch import torch
import torch.nn.functional as F import torch.nn.functional as F
from torch import Tensor
from torch.nn.parameter import Parameter
from colossalai.communication import broadcast from colossalai.communication import broadcast
from colossalai.context import ParallelMode, seed from colossalai.context import ParallelMode, seed
from colossalai.core import global_context as gpc from colossalai.core import global_context as gpc
...@@ -14,18 +17,33 @@ from colossalai.global_variables import tensor_parallel_env as env ...@@ -14,18 +17,33 @@ from colossalai.global_variables import tensor_parallel_env as env
from colossalai.kernel import LayerNorm from colossalai.kernel import LayerNorm
from colossalai.nn import init as init from colossalai.nn import init as init
from colossalai.registry import LAYERS from colossalai.registry import LAYERS
from colossalai.utils.checkpointing import (broadcast_state_dict, gather_tensor_parallel_state_dict, from colossalai.utils.checkpointing import (
partition_tensor_parallel_state_dict) broadcast_state_dict,
gather_tensor_parallel_state_dict,
partition_tensor_parallel_state_dict,
)
from colossalai.utils.cuda import get_current_device from colossalai.utils.cuda import get_current_device
from torch import Tensor
from torch.nn.parameter import Parameter
from ..vanilla import VanillaPatchEmbedding, VanillaLayerNorm
from ..base_layer import ParallelLayer from ..base_layer import ParallelLayer
from ..colossalai_layer._utils import ColossalaiModule from ..colossalai_layer._utils import ColossalaiModule
from ..utils import divide, set_tensor_parallel_attribute_by_partition from ..utils import divide, set_tensor_parallel_attribute_by_partition
from ._utils import (gather_forward_split_backward, get_parallel_input, reduce_grad, reduce_input, set_parallel_input, from ..vanilla import VanillaLayerNorm, VanillaPatchEmbedding
split_forward_gather_backward) from ._operation import linear_with_async_comm
from ._utils import (
gather_forward_split_backward,
get_parallel_input,
reduce_grad,
reduce_input,
set_parallel_input,
split_forward_gather_backward,
)
Fast_LN = None
try:
from apex.contrib.layer_norm.layer_norm import FastLayerNorm
Fast_LN = FastLayerNorm
except ImportError:
pass
@LAYERS.register_module @LAYERS.register_module
...@@ -59,12 +77,11 @@ class Linear1D(ColossalaiModule): ...@@ -59,12 +77,11 @@ class Linear1D(ColossalaiModule):
weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)), weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)),
bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1)): bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1)):
parallel_input = get_parallel_input() parallel_input = get_parallel_input()
if not parallel_input: if not parallel_input and not gather_output:
layer = Linear1D_Col(in_features, layer = Linear1D_Col(in_features,
out_features, out_features,
bias=bias, bias=bias,
dtype=dtype, dtype=dtype,
gather_output=gather_output,
skip_bias_add=skip_bias_add, skip_bias_add=skip_bias_add,
weight_initializer=weight_initializer, weight_initializer=weight_initializer,
bias_initializer=bias_initializer) bias_initializer=bias_initializer)
...@@ -96,8 +113,21 @@ class LayerNorm1D(ColossalaiModule): ...@@ -96,8 +113,21 @@ class LayerNorm1D(ColossalaiModule):
dtype (:class:`torch.dtype`, optional): The dtype of parameters, defaults to None. dtype (:class:`torch.dtype`, optional): The dtype of parameters, defaults to None.
""" """
_fast_ln_supported_sizes = [
1024, 1536, 2048, 2304, 3072, 3840, 4096, 5120, 6144, 8192, 10240, 12288, 12800, 15360, 16384, 18432, 20480,
24576, 25600, 30720, 32768, 40960, 49152, 65536
]
def __init__(self, normalized_shape: int, eps=1e-05, bias=True, dtype=None): def __init__(self, normalized_shape: int, eps=1e-05, bias=True, dtype=None):
norm = VanillaLayerNorm(normalized_shape, eps=eps, bias=bias, dtype=dtype) if Fast_LN is not None and normalized_shape in self._fast_ln_supported_sizes:
norm = Fast_LN(normalized_shape, eps=eps).to(dtype)
else:
norm = None
try:
from apex.normalization import FusedLayerNorm
norm = FusedLayerNorm(normalized_shape, eps=eps).to(dtype)
except ImportError:
norm = LayerNorm(normalized_shape, eps=eps).to(dtype)
super().__init__(norm) super().__init__(norm)
def _load_from_state_dict(self, state_dict, prefix, *args): def _load_from_state_dict(self, state_dict, prefix, *args):
...@@ -519,11 +549,12 @@ class Linear1D_Col(ParallelLayer): ...@@ -519,11 +549,12 @@ class Linear1D_Col(ParallelLayer):
'Invalid shapes in Linear1D_Col forward: input={}, weight={}. Expected last dim of input {}.'.format( 'Invalid shapes in Linear1D_Col forward: input={}, weight={}. Expected last dim of input {}.'.format(
input_.shape, self.weight.shape, self.weight.shape[-1]) input_.shape, self.weight.shape, self.weight.shape[-1])
# Set up backprop all-reduce. # Set up backprop all-reduce.
input_parallel = reduce_grad(input_, ParallelMode.PARALLEL_1D) # input_parallel = reduce_grad(input_, ParallelMode.PARALLEL_1D)
input_parallel = input_
# Matrix multiply. # Matrix multiply.
bias = self.bias if not self.skip_bias_add else None bias = self.bias if not self.skip_bias_add else None
output_parallel = F.linear(input_parallel, self.weight, bias) # output_parallel = F.linear(input_parallel, self.weight, bias)
output_parallel = linear_with_async_comm(input_parallel, self.weight, bias, ParallelMode.PARALLEL_1D, True)
if self.gather_output: if self.gather_output:
# All-gather across the partitions. # All-gather across the partitions.
output = gather_forward_split_backward(output_parallel, ParallelMode.PARALLEL_1D, dim=-1) output = gather_forward_split_backward(output_parallel, ParallelMode.PARALLEL_1D, dim=-1)
...@@ -565,9 +596,12 @@ class Linear1D_Row(ParallelLayer): ...@@ -565,9 +596,12 @@ class Linear1D_Row(ParallelLayer):
parallel_input: bool = True, parallel_input: bool = True,
skip_bias_add: bool = False, skip_bias_add: bool = False,
weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)), weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)),
bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1)): bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1),
stream_chunk_num: int = 1):
super().__init__() super().__init__()
self.stream_chunk_num = stream_chunk_num
# Keep input parameters # Keep input parameters
self.in_features = in_features self.in_features = in_features
self.out_features = out_features self.out_features = out_features
...@@ -585,6 +619,9 @@ class Linear1D_Row(ParallelLayer): ...@@ -585,6 +619,9 @@ class Linear1D_Row(ParallelLayer):
factory_kwargs = {'device': get_current_device(), 'dtype': dtype} factory_kwargs = {'device': get_current_device(), 'dtype': dtype}
self.weight = Parameter(torch.empty(self.out_features, self.input_size_per_partition, **factory_kwargs)) self.weight = Parameter(torch.empty(self.out_features, self.input_size_per_partition, **factory_kwargs))
if self.stream_chunk_num > 1:
# TODO() work for inference only
self.chunk_weight()
if bias: if bias:
self.bias = Parameter(torch.empty(self.out_features, **factory_kwargs)) self.bias = Parameter(torch.empty(self.out_features, **factory_kwargs))
else: else:
...@@ -594,6 +631,9 @@ class Linear1D_Row(ParallelLayer): ...@@ -594,6 +631,9 @@ class Linear1D_Row(ParallelLayer):
self._set_tensor_parallel_attributes() self._set_tensor_parallel_attributes()
set_parallel_input(False) set_parallel_input(False)
def chunk_weight(self):
self.weight_list = torch.chunk(self.weight, self.stream_chunk_num, dim=0)
def reset_parameters(self, weight_initializer, bias_initializer) -> None: def reset_parameters(self, weight_initializer, bias_initializer) -> None:
fan_in, fan_out = self.in_features, self.out_features fan_in, fan_out = self.in_features, self.out_features
weight_initializer(self.weight, fan_in=fan_in, fan_out=fan_out) weight_initializer(self.weight, fan_in=fan_in, fan_out=fan_out)
...@@ -664,9 +704,26 @@ class Linear1D_Row(ParallelLayer): ...@@ -664,9 +704,26 @@ class Linear1D_Row(ParallelLayer):
input_.shape, self.weight.shape, self.weight.shape[-1] * gpc.tensor_parallel_size) input_.shape, self.weight.shape, self.weight.shape[-1] * gpc.tensor_parallel_size)
input_ = split_forward_gather_backward(input_, ParallelMode.PARALLEL_1D, dim=-1) input_ = split_forward_gather_backward(input_, ParallelMode.PARALLEL_1D, dim=-1)
output_parallel = F.linear(input_, self.weight) if self.stream_chunk_num > 1:
output = reduce_input(output_parallel, ParallelMode.PARALLEL_1D) if self.training:
raise RuntimeError("use stream_chunk_num=1 in Linear1D_Row for training!")
with torch.no_grad():
output_parallel_list = [None for i in range(self.stream_chunk_num)]
handle_list = []
for i in range(self.stream_chunk_num):
output_parallel_list[i] = F.linear(input_, self.weight_list[i])
handle = torch.distributed.all_reduce(output_parallel_list[i],
group=gpc.get_group(ParallelMode.PARALLEL_1D),
async_op=True)
handle_list.append(handle)
# output_parallel_list[i] = reduce_input(output_parallel_list[i], ParallelMode.PARALLEL_1D)
for handle in handle_list:
handle.wait()
output = torch.cat(output_parallel_list, dim=-1)
else:
output_parallel = F.linear(input_, self.weight)
# output_parallel = linear_with_async_comm(input_, self.weight, None, ParallelMode.PARALLEL_1D, False)
output = reduce_input(output_parallel, ParallelMode.PARALLEL_1D)
if not self.skip_bias_add: if not self.skip_bias_add:
if self.bias is not None: if self.bias is not None:
output = output + self.bias output = output + self.bias
......
...@@ -4,91 +4,163 @@ ...@@ -4,91 +4,163 @@
from typing import Optional, Tuple from typing import Optional, Tuple
import torch import torch
from colossalai.communication import (all_gather, all_reduce, broadcast, reduce, reduce_scatter)
from colossalai.context.parallel_mode import ParallelMode
from colossalai.core import global_context as gpc
from torch import Tensor from torch import Tensor
from torch.cuda.amp import custom_bwd, custom_fwd from torch.cuda.amp import custom_bwd, custom_fwd
from ._utils import get_parallel_mode_from_env
from colossalai.communication import all_gather, all_reduce, broadcast, reduce, reduce_scatter
from colossalai.constants import INPUT_GROUP_3D, WEIGHT_GROUP_3D from colossalai.constants import INPUT_GROUP_3D, WEIGHT_GROUP_3D
from colossalai.context.parallel_mode import ParallelMode
from colossalai.core import global_context as gpc
from ._utils import get_parallel_mode_from_env, push_async_grad
class _Linear3D(torch.autograd.Function): class _Linear3D(torch.autograd.Function):
@staticmethod @staticmethod
@custom_fwd(cast_inputs=torch.float16) @custom_fwd(cast_inputs=torch.float16)
def forward(ctx, def forward(
input_: Tensor, ctx,
weight: Tensor, input_: Tensor,
bias: Optional[Tensor], weight: Tensor,
input_parallel_mode: ParallelMode, weight_id: int,
weight_parallel_mode: ParallelMode, input_parallel_mode: ParallelMode,
output_parallel_mode: ParallelMode, weight_parallel_mode: ParallelMode,
input_dim: int = 0, output_parallel_mode: ParallelMode,
weight_dim: int = -1, ) -> Tensor:
output_dim: int = 0) -> Tensor: ctx.weight_id = weight_id
ctx.use_bias = bias is not None ctx.input_parallel_mode = input_parallel_mode
ctx.weight_parallel_mode = weight_parallel_mode
ctx.output_parallel_mode = output_parallel_mode
input_ = all_gather(input_, input_dim, input_parallel_mode) input_ = all_gather(input_, 0, input_parallel_mode)
weight = all_gather(weight, weight_dim, weight_parallel_mode) weight = all_gather(weight, 0, weight_parallel_mode)
ctx.save_for_backward(input_, weight) ctx.save_for_backward(input_, weight)
output = torch.matmul(input_, weight) output = torch.matmul(input_, weight)
output = reduce_scatter(output, output_dim, output_parallel_mode) output = reduce_scatter(output, 0, output_parallel_mode)
return output
@staticmethod
@custom_bwd
def backward(ctx, output_grad: Tensor) -> Tuple[Tensor, ...]:
input_, weight = ctx.saved_tensors
output_grad = all_gather(output_grad, 0, ctx.output_parallel_mode)
input_grad = torch.matmul(output_grad, weight.transpose(0, 1))
input_grad, input_op = reduce_scatter(input_grad, 0, ctx.input_parallel_mode, async_op=True)
weight_grad = torch.matmul(
input_.reshape(-1, input_.shape[-1]).transpose(0, 1), output_grad.reshape(-1, output_grad.shape[-1]))
weight_grad, op = reduce_scatter(weight_grad, 0, ctx.weight_parallel_mode, async_op=True)
weight_grad = push_async_grad(op, weight_grad, ctx.weight_id)
input_op.wait()
return input_grad, weight_grad, None, None, None, None
def linear_3d(
input_: Tensor,
weight: Tensor,
input_parallel_mode: ParallelMode,
weight_parallel_mode: ParallelMode,
output_parallel_mode: ParallelMode,
) -> Tensor:
r"""Linear layer for 3D parallelism.
Args:
input_ (:class:`torch.tensor`): input matrix.
weight (:class:`torch.tensor`): matrix of weight.
input_parallel_mode (:class:`colossalai.context.parallel_mode.ParallelMode`): input parallel mode.
weight_parallel_mode (:class:`colossalai.context.parallel_mode.ParallelMode`): weight parallel mode.
output_parallel_mode (:class:`colossalai.context.parallel_mode.ParallelMode`): output parallel mode.
Note:
The parallel_mode should be concluded in ``ParallelMode``. More details about ``ParallelMode`` could be found
in `parallel_mode <https://github.com/hpcaitech/ColossalAI/blob/main/colossalai/context/parallel_mode.py>`_
"""
return _Linear3D.apply(
input_,
weight,
id(weight),
input_parallel_mode,
weight_parallel_mode,
output_parallel_mode,
)
class _Classifier3D(torch.autograd.Function):
@staticmethod
@custom_fwd(cast_inputs=torch.float16)
def forward(
ctx,
input_: Tensor,
weight: Tensor,
bias: Optional[Tensor],
weight_id: int,
bias_id: Optional[int],
input_parallel_mode: ParallelMode,
weight_parallel_mode: ParallelMode,
output_parallel_mode: ParallelMode,
) -> Tensor:
ctx.use_bias = bias is not None
ctx.weight_id = weight_id
src_rank = gpc.get_ranks_in_group(input_parallel_mode)[gpc.get_local_rank(output_parallel_mode)]
weight = broadcast(weight, src_rank, input_parallel_mode)
ctx.save_for_backward(input_, weight)
output = torch.matmul(input_, weight.transpose(0, 1))
output = all_reduce(output, output_parallel_mode)
if bias is not None: if bias is not None:
ctx.bias_id = bias_id
output += bias output += bias
ctx.src_rank = src_rank
ctx.input_parallel_mode = input_parallel_mode ctx.input_parallel_mode = input_parallel_mode
ctx.weight_parallel_mode = weight_parallel_mode ctx.weight_parallel_mode = weight_parallel_mode
ctx.output_parallel_mode = output_parallel_mode ctx.output_parallel_mode = output_parallel_mode
ctx.input_dim = input_dim
ctx.weight_dim = weight_dim
ctx.output_dim = output_dim
return output return output
@staticmethod @staticmethod
@custom_bwd @custom_bwd
def backward(ctx, output_grad: Tensor) -> Tuple[Tensor, ...]: def backward(ctx, output_grad: Tensor) -> Tuple[Tensor, ...]:
input_, weight = ctx.saved_tensors input_, weight = ctx.saved_tensors
with torch.no_grad(): weight_grad = torch.matmul(
output_grad = all_gather(output_grad, ctx.output_dim, ctx.output_parallel_mode) output_grad.reshape(-1, output_grad.shape[-1]).transpose(0, 1), input_.reshape(-1, input_.shape[-1]))
weight_grad = reduce(weight_grad, ctx.src_rank, ctx.input_parallel_mode)
async_ops = list() if gpc.get_local_rank(ctx.input_parallel_mode) == gpc.get_local_rank(ctx.output_parallel_mode):
weight_grad, op = all_reduce(weight_grad, ctx.weight_parallel_mode, async_op=True)
input_grad = torch.matmul(output_grad, weight.transpose(0, 1)) weight_grad = push_async_grad(op, weight_grad, ctx.weight_id)
input_grad, op = reduce_scatter(input_grad, ctx.input_dim, ctx.input_parallel_mode, async_op=True) else:
async_ops.append(op) weight_grad = None
weight_grad = torch.matmul( if ctx.use_bias:
input_.reshape(-1, input_.shape[-1]).transpose(0, 1), output_grad.reshape(-1, output_grad.shape[-1])) bias_grad = torch.sum(output_grad, dim=tuple(range(len(output_grad.shape))[:-1]))
weight_grad, op = reduce_scatter(weight_grad, ctx.weight_dim, ctx.weight_parallel_mode, async_op=True) bias_grad = all_reduce(bias_grad, ctx.input_parallel_mode)
async_ops.append(op) bias_grad, op = all_reduce(bias_grad, ctx.weight_parallel_mode, async_op=True)
bias_grad = push_async_grad(op, bias_grad, ctx.bias_id)
if ctx.use_bias: else:
bias_grad = torch.sum(output_grad, dim=tuple(range(len(output_grad.shape))[:-1])) bias_grad = None
bias_grad, op = all_reduce(bias_grad, ctx.weight_parallel_mode, async_op=True)
async_ops.append(op) input_grad = torch.matmul(output_grad, weight)
else:
bias_grad = None return input_grad, weight_grad, bias_grad, None, None, None, None, None
for op in async_ops:
if op is not None: def classifier_3d(
op.wait() input_: Tensor,
weight: Tensor,
return input_grad, weight_grad, bias_grad, None, None, None, None, None, None bias: Optional[Tensor],
input_parallel_mode: ParallelMode,
weight_parallel_mode: ParallelMode,
def linear_3d(input_: Tensor, output_parallel_mode: ParallelMode,
weight: Tensor, ) -> Tensor:
bias: Optional[Tensor], r"""3D parallel classifier.
input_parallel_mode: ParallelMode,
weight_parallel_mode: ParallelMode,
output_parallel_mode: ParallelMode,
input_dim: int = 0,
weight_dim: int = -1,
output_dim: int = 0) -> Tensor:
r"""Linear layer for 3D parallelism.
Args: Args:
input_ (:class:`torch.tensor`): input matrix. input_ (:class:`torch.tensor`): input matrix.
...@@ -97,38 +169,52 @@ def linear_3d(input_: Tensor, ...@@ -97,38 +169,52 @@ def linear_3d(input_: Tensor,
input_parallel_mode (:class:`colossalai.context.parallel_mode.ParallelMode`): input parallel mode. input_parallel_mode (:class:`colossalai.context.parallel_mode.ParallelMode`): input parallel mode.
weight_parallel_mode (:class:`colossalai.context.parallel_mode.ParallelMode`): weight parallel mode. weight_parallel_mode (:class:`colossalai.context.parallel_mode.ParallelMode`): weight parallel mode.
output_parallel_mode (:class:`colossalai.context.parallel_mode.ParallelMode`): output parallel mode. output_parallel_mode (:class:`colossalai.context.parallel_mode.ParallelMode`): output parallel mode.
input_dim (int, optional): dimension of input, defaults to 0.
weight_dim (int, optional): dimension of weight, defaults to -1.
output_dim (int, optional): dimension of output, defaults to 0.
Note: Note:
The parallel_mode should be concluded in ``ParallelMode``. More details about ``ParallelMode`` could be found The parallel_mode should be concluded in ``ParallelMode``. More details about ``ParallelMode`` could be found
in `parallel_mode <https://github.com/hpcaitech/ColossalAI/blob/main/colossalai/context/parallel_mode.py>`_ in `parallel_mode <https://github.com/hpcaitech/ColossalAI/blob/main/colossalai/context/parallel_mode.py>`_
""" """
return _Linear3D.apply(input_, weight, bias, input_parallel_mode, weight_parallel_mode, output_parallel_mode, return _Classifier3D.apply(
input_dim, weight_dim, output_dim) input_,
weight,
bias,
id(weight),
id(bias) if bias is not None else None,
input_parallel_mode,
weight_parallel_mode,
output_parallel_mode,
)
class _Classifier3D(torch.autograd.Function): class _VocabParallelClassifier3D(torch.autograd.Function):
@staticmethod @staticmethod
@custom_fwd(cast_inputs=torch.float16) @custom_fwd(cast_inputs=torch.float16)
def forward(ctx, input_: Tensor, weight: Tensor, bias: Optional[Tensor], input_parallel_mode: ParallelMode, def forward(
weight_parallel_mode: ParallelMode, output_parallel_mode: ParallelMode) -> Tensor: ctx,
input_: Tensor,
weight: Tensor,
bias: Optional[Tensor],
weight_id: int,
bias_id: Optional[int],
input_parallel_mode: ParallelMode,
weight_parallel_mode: ParallelMode,
output_parallel_mode: ParallelMode,
) -> Tensor:
ctx.use_bias = bias is not None ctx.use_bias = bias is not None
ctx.weight_id = weight_id
ranks_in_group = gpc.get_ranks_in_group(input_parallel_mode) input_ = all_gather(input_, 0, input_parallel_mode)
src_rank = ranks_in_group[gpc.get_local_rank(output_parallel_mode)] weight = all_gather(weight, 0, weight_parallel_mode).transpose(0, 1)
weight = broadcast(weight, src_rank, input_parallel_mode)
ctx.save_for_backward(input_, weight) ctx.save_for_backward(input_, weight)
output = torch.matmul(input_, weight.transpose(0, 1)) output = torch.matmul(input_, weight)
output = all_reduce(output, output_parallel_mode) output = reduce_scatter(output, 0, output_parallel_mode)
if bias is not None: if bias is not None:
ctx.bias_id = bias_id
output += bias output += bias
ctx.src_rank = src_rank
ctx.input_parallel_mode = input_parallel_mode ctx.input_parallel_mode = input_parallel_mode
ctx.weight_parallel_mode = weight_parallel_mode ctx.weight_parallel_mode = weight_parallel_mode
ctx.output_parallel_mode = output_parallel_mode ctx.output_parallel_mode = output_parallel_mode
...@@ -138,38 +224,37 @@ class _Classifier3D(torch.autograd.Function): ...@@ -138,38 +224,37 @@ class _Classifier3D(torch.autograd.Function):
@custom_bwd @custom_bwd
def backward(ctx, output_grad: Tensor) -> Tuple[Tensor, ...]: def backward(ctx, output_grad: Tensor) -> Tuple[Tensor, ...]:
input_, weight = ctx.saved_tensors input_, weight = ctx.saved_tensors
with torch.no_grad(): output_grad = all_gather(output_grad, 0, ctx.output_parallel_mode)
async_ops = list()
weight_grad = torch.matmul( input_grad = torch.matmul(output_grad, weight.transpose(0, 1))
output_grad.reshape(-1, output_grad.shape[-1]).transpose(0, 1), input_.reshape(-1, input_.shape[-1])) input_grad, input_op = reduce_scatter(input_grad, 0, ctx.input_parallel_mode, async_op=True)
weight_grad = reduce(weight_grad, ctx.src_rank, ctx.input_parallel_mode)
if gpc.get_local_rank(ctx.input_parallel_mode) == gpc.get_local_rank(ctx.output_parallel_mode):
weight_grad, op = all_reduce(weight_grad, ctx.weight_parallel_mode, async_op=True)
async_ops.append(op)
else:
weight_grad = None
if ctx.use_bias: weight_grad = torch.matmul(
bias_grad = torch.sum(output_grad, dim=tuple(range(len(output_grad.shape))[:-1])) input_.reshape(-1, input_.shape[-1]).transpose(0, 1), output_grad.reshape(-1, output_grad.shape[-1]))
bias_grad = all_reduce(bias_grad, ctx.input_parallel_mode) weight_grad, op = reduce_scatter(weight_grad.transpose(0, 1), 0, ctx.weight_parallel_mode, async_op=True)
bias_grad, op = all_reduce(bias_grad, ctx.weight_parallel_mode, async_op=True) weight_grad = push_async_grad(op, weight_grad, ctx.weight_id)
async_ops.append(op)
else:
bias_grad = None
input_grad = torch.matmul(output_grad, weight) if ctx.use_bias:
bias_grad = torch.sum(output_grad, dim=tuple(range(len(output_grad.shape))[:-1]))
bias_grad, op = all_reduce(bias_grad, ctx.weight_parallel_mode, async_op=True)
bias_grad = push_async_grad(op, bias_grad, ctx.bias_id)
else:
bias_grad = None
for op in async_ops: input_op.wait()
if op is not None:
op.wait()
return input_grad, weight_grad, bias_grad, None, None, None, None, None, None return input_grad, weight_grad, bias_grad, None, None, None, None, None
def classifier_3d(input_: Tensor, weight: Tensor, bias: Optional[Tensor], input_parallel_mode: ParallelMode, def vocab_parallel_classifier_3d(
weight_parallel_mode: ParallelMode, output_parallel_mode: ParallelMode) -> Tensor: input_: Tensor,
r"""3D parallel classifier. weight: Tensor,
bias: Optional[Tensor],
input_parallel_mode: ParallelMode,
weight_parallel_mode: ParallelMode,
output_parallel_mode: ParallelMode,
) -> Tensor:
r"""3D vocab parallel classifier.
Args: Args:
input_ (:class:`torch.tensor`): input matrix. input_ (:class:`torch.tensor`): input matrix.
...@@ -183,33 +268,72 @@ def classifier_3d(input_: Tensor, weight: Tensor, bias: Optional[Tensor], input_ ...@@ -183,33 +268,72 @@ def classifier_3d(input_: Tensor, weight: Tensor, bias: Optional[Tensor], input_
The parallel_mode should be concluded in ``ParallelMode``. More details about ``ParallelMode`` could be found The parallel_mode should be concluded in ``ParallelMode``. More details about ``ParallelMode`` could be found
in `parallel_mode <https://github.com/hpcaitech/ColossalAI/blob/main/colossalai/context/parallel_mode.py>`_ in `parallel_mode <https://github.com/hpcaitech/ColossalAI/blob/main/colossalai/context/parallel_mode.py>`_
""" """
return _Classifier3D.apply(input_, weight, bias, input_parallel_mode, weight_parallel_mode, output_parallel_mode) return _VocabParallelClassifier3D.apply(
input_,
weight,
bias,
id(weight),
id(bias) if bias is not None else None,
input_parallel_mode,
weight_parallel_mode,
output_parallel_mode,
)
@torch.jit.script
def norm_forward(x: Tensor, mean: Tensor, sqr_mean: Tensor, weight: Tensor, bias: Tensor, eps: float):
mu = x - mean
var = sqr_mean - mean**2
sigma = torch.sqrt(var + eps)
z = mu / sigma
output = weight * z + bias
return output, mu, sigma
@torch.jit.script
def norm_backward(grad: Tensor, mu: Tensor, sigma: Tensor, weight: Tensor):
# dbias, dweight = grad, grad * mu / sigma
dz = grad * weight
dmu = dz / sigma
dvar = dz * mu * (-0.5) * sigma**(-3)
dmean = -dmu
dvar = torch.sum(dvar, -1, keepdim=True)
dmean = torch.sum(dmean, -1, keepdim=True)
return dmu, dmean, dvar
class _Layernorm3D(torch.autograd.Function): class _Layernorm3D(torch.autograd.Function):
@staticmethod @staticmethod
@custom_fwd(cast_inputs=torch.float32) @custom_fwd(cast_inputs=torch.float32)
def forward(ctx, input_: Tensor, weight: Tensor, bias: Optional[Tensor], normalized_shape: int, eps: float, def forward(
input_parallel_mode: ParallelMode, weight_parallel_mode: ParallelMode, ctx,
output_parallel_mode: ParallelMode) -> Tensor: input_: Tensor,
mean = all_reduce(torch.sum(input_, dim=-1, keepdim=True), output_parallel_mode) / normalized_shape weight: Tensor,
mu = input_ - mean bias: Tensor,
var = all_reduce(torch.sum(mu**2, dim=-1, keepdim=True), output_parallel_mode) / normalized_shape weight_id: int,
sigma = torch.sqrt(var + eps) bias_id: int,
normalized_shape: int,
eps: float,
output_parallel_mode: ParallelMode,
input_x_weight_parallel_mode: ParallelMode,
) -> Tensor:
ctx.weight_id = weight_id
ctx.bias_id = bias_id
sum_ = torch.sum(input_, dim=-1, keepdim=True)
sqr_sum = torch.sum(input_**2, dim=-1, keepdim=True)
mean, sqr_mean = all_reduce(torch.stack((sum_, sqr_sum)), output_parallel_mode) / normalized_shape
output, mu, sigma = norm_forward(input_, mean, sqr_mean, weight, bias, eps)
ctx.save_for_backward(mu, sigma, weight) ctx.save_for_backward(mu, sigma, weight)
z = mu / sigma
output = weight * z
if bias is not None:
output = output + bias
ctx.use_bias = bias is not None
ctx.normalized_shape = normalized_shape ctx.normalized_shape = normalized_shape
ctx.input_parallel_mode = input_parallel_mode
ctx.weight_parallel_mode = weight_parallel_mode
ctx.output_parallel_mode = output_parallel_mode ctx.output_parallel_mode = output_parallel_mode
ctx.input_x_weight_parallel_mode = input_x_weight_parallel_mode
return output return output
...@@ -217,34 +341,31 @@ class _Layernorm3D(torch.autograd.Function): ...@@ -217,34 +341,31 @@ class _Layernorm3D(torch.autograd.Function):
@custom_bwd @custom_bwd
def backward(ctx, output_grad: Tensor) -> Tuple[Tensor, ...]: def backward(ctx, output_grad: Tensor) -> Tuple[Tensor, ...]:
mu, sigma, weight = ctx.saved_tensors mu, sigma, weight = ctx.saved_tensors
with torch.no_grad():
weight_grad = output_grad * mu / sigma
if ctx.use_bias:
bias_grad = output_grad
weight_grad = torch.stack([bias_grad, weight_grad]).contiguous()
else:
bias_grad = None
weight_grad = torch.sum(weight_grad, dim=tuple(range(len(weight_grad.shape))[1:-1]))
weight_grad = all_reduce(weight_grad, ctx.weight_parallel_mode)
weight_grad = all_reduce(weight_grad, ctx.input_parallel_mode)
if ctx.use_bias:
bias_grad, weight_grad = weight_grad[0], weight_grad[1]
dz = output_grad * weight
dvar = dz * mu * (-0.5) * sigma**(-3)
dvar = all_reduce(torch.sum(dvar, dim=-1, keepdim=True), ctx.output_parallel_mode)
dmean = dz * (-1 / sigma) + dvar * -2 * mu / ctx.normalized_shape
dmean = all_reduce(torch.sum(dmean, dim=-1, keepdim=True), ctx.output_parallel_mode)
input_grad = dz / sigma + dvar * 2 * mu / \
ctx.normalized_shape + dmean / ctx.normalized_shape
return input_grad, weight_grad, bias_grad, None, None, None, None, None bias_grad, weight_grad = output_grad, output_grad * mu / sigma
bias_grad = torch.sum(bias_grad, dim=tuple(range(len(bias_grad.shape))[:-1]))
bias_grad, op = all_reduce(bias_grad, ctx.input_x_weight_parallel_mode, async_op=True)
def layernorm_3d(input_: Tensor, weight: Tensor, bias: Optional[Tensor], normalized_shape: int, eps: float, bias_grad = push_async_grad(op, bias_grad, ctx.bias_id)
input_parallel_mode: ParallelMode, weight_parallel_mode: ParallelMode, weight_grad = torch.sum(weight_grad, dim=tuple(range(len(weight_grad.shape))[:-1]))
output_parallel_mode: ParallelMode) -> Tensor: weight_grad, op = all_reduce(weight_grad, ctx.input_x_weight_parallel_mode, async_op=True)
weight_grad = push_async_grad(op, weight_grad, ctx.weight_id)
dmu, dmean, dvar = norm_backward(output_grad, mu, sigma, weight)
dvar, dmean = all_reduce(torch.stack((dvar, dmean)), ctx.output_parallel_mode)
input_grad = dmu + (dmean + 2 * dvar * mu) / ctx.normalized_shape
return input_grad, weight_grad, bias_grad, None, None, None, None, None, None, None, None
def layernorm_3d(
input_: Tensor,
weight: Tensor,
bias: Tensor,
normalized_shape: int,
eps: float,
output_parallel_mode: ParallelMode,
input_x_weight_parallel_mode: ParallelMode,
) -> Tensor:
r"""3D parallel Layernorm. r"""3D parallel Layernorm.
Args: Args:
...@@ -257,16 +378,24 @@ def layernorm_3d(input_: Tensor, weight: Tensor, bias: Optional[Tensor], normali ...@@ -257,16 +378,24 @@ def layernorm_3d(input_: Tensor, weight: Tensor, bias: Optional[Tensor], normali
If a single integer is used, it is treated as a singleton list, and this module will If a single integer is used, it is treated as a singleton list, and this module will
normalize over the last dimension which is expected to be of that specific size. normalize over the last dimension which is expected to be of that specific size.
eps (float): a value added to the denominator for numerical stability eps (float): a value added to the denominator for numerical stability
input_parallel_mode (:class:`colossalai.context.parallel_mode.ParallelMode`): input parallel mode.
weight_parallel_mode (:class:`colossalai.context.parallel_mode.ParallelMode`): weight parallel mode.
output_parallel_mode (:class:`colossalai.context.parallel_mode.ParallelMode`): output parallel mode. output_parallel_mode (:class:`colossalai.context.parallel_mode.ParallelMode`): output parallel mode.
input_x_weight_parallel_mode (:class:`colossalai.context.parallel_mode.ParallelMode`): input x weight parallel mode.
Note: Note:
The parallel_mode should be concluded in ``ParallelMode``. More details about ``ParallelMode`` could be found The parallel_mode should be concluded in ``ParallelMode``. More details about ``ParallelMode`` could be found
in `parallel_mode <https://github.com/hpcaitech/ColossalAI/blob/main/colossalai/context/parallel_mode.py>`_ in `parallel_mode <https://github.com/hpcaitech/ColossalAI/blob/main/colossalai/context/parallel_mode.py>`_
""" """
return _Layernorm3D.apply(input_, weight, bias, normalized_shape, eps, input_parallel_mode, weight_parallel_mode, return _Layernorm3D.apply(
output_parallel_mode) input_,
weight,
bias,
id(weight),
id(bias),
normalized_shape,
eps,
output_parallel_mode,
input_x_weight_parallel_mode,
)
def split_tensor_3d(tensor: Tensor, dim: int, parallel_mode: ParallelMode) -> Tensor: def split_tensor_3d(tensor: Tensor, dim: int, parallel_mode: ParallelMode) -> Tensor:
...@@ -315,17 +444,12 @@ def split_batch_3d(input_: Tensor, ...@@ -315,17 +444,12 @@ def split_batch_3d(input_: Tensor,
The parallel_mode should be concluded in ``ParallelMode``. More details about ``ParallelMode`` could be found The parallel_mode should be concluded in ``ParallelMode``. More details about ``ParallelMode`` could be found
in `parallel_mode <https://github.com/hpcaitech/ColossalAI/blob/main/colossalai/context/parallel_mode.py>`_. in `parallel_mode <https://github.com/hpcaitech/ColossalAI/blob/main/colossalai/context/parallel_mode.py>`_.
""" """
dim_size = input_.size(dim) if input_.size(dim) <= 1:
return input_
weight_parallel_mode = get_parallel_mode_from_env(WEIGHT_GROUP_3D) weight_parallel_mode = get_parallel_mode_from_env(WEIGHT_GROUP_3D)
input_parallel_mode = get_parallel_mode_from_env(INPUT_GROUP_3D) input_parallel_mode = get_parallel_mode_from_env(INPUT_GROUP_3D)
weight_world_size = gpc.get_world_size(weight_parallel_mode) weight_world_size = gpc.get_world_size(weight_parallel_mode)
input_world_size = gpc.get_world_size(input_parallel_mode) input_world_size = gpc.get_world_size(input_parallel_mode)
assert dim_size % (input_world_size*weight_world_size) == 0, \
f'The batch size ({dim_size}) is not a multiple of square of 3D depth ({input_world_size*weight_world_size}).'
if input_.size(dim) <= 1:
return input_
output = torch.chunk(input_, weight_world_size, dim=dim)[gpc.get_local_rank(weight_parallel_mode)].contiguous() output = torch.chunk(input_, weight_world_size, dim=dim)[gpc.get_local_rank(weight_parallel_mode)].contiguous()
output = torch.chunk(output, input_world_size, dim=dim)[gpc.get_local_rank(input_parallel_mode)].contiguous() output = torch.chunk(output, input_world_size, dim=dim)[gpc.get_local_rank(input_parallel_mode)].contiguous()
return output return output
...@@ -464,47 +588,3 @@ def reduce_by_batch_3d(tensor: Tensor, ...@@ -464,47 +588,3 @@ def reduce_by_batch_3d(tensor: Tensor,
in `parallel_mode <https://github.com/hpcaitech/ColossalAI/blob/main/colossalai/context/parallel_mode.py>`_ in `parallel_mode <https://github.com/hpcaitech/ColossalAI/blob/main/colossalai/context/parallel_mode.py>`_
""" """
return _ReduceByBatch3D.apply(tensor, input_parallel_mode, weight_parallel_mode, reduce_mean) return _ReduceByBatch3D.apply(tensor, input_parallel_mode, weight_parallel_mode, reduce_mean)
class _BroadcastWeight3D_FromDiagonal(torch.autograd.Function):
r"""broadcast weight from diagonal.
Args:
input_ (:class:`torch.tensor`): input matrix.
input_parallel_mode (:class:`colossalai.context.parallel_mode.ParallelMode`): input parallel mode.
weight_parallel_mode (:class:`colossalai.context.parallel_mode.ParallelMode`): weight parallel mode.
output_parallel_mode (:class:`colossalai.context.parallel_mode.ParallelMode`): output parallel mode.
Note:
The parallel_mode should be concluded in ``ParallelMode``. More details about ``ParallelMode`` could be found
in `parallel_mode <https://github.com/hpcaitech/ColossalAI/blob/main/colossalai/context/parallel_mode.py>`_
"""
@staticmethod
@custom_fwd(cast_inputs=torch.float16)
def forward(ctx, input_: Tensor, input_parallel_mode: ParallelMode, weight_parallel_mode: ParallelMode,
output_parallel_mode: ParallelMode) -> Tensor:
ranks_in_group = gpc.get_ranks_in_group(input_parallel_mode)
src_rank = ranks_in_group[gpc.get_local_rank(output_parallel_mode)]
output = broadcast(input_, src_rank, input_parallel_mode)
ctx.src_rank = src_rank
ctx.input_parallel_mode = input_parallel_mode
ctx.weight_parallel_mode = weight_parallel_mode
ctx.output_parallel_mode = output_parallel_mode
return output
@staticmethod
@custom_bwd
def backward(ctx, output_grad: Tensor) -> Tuple[Tensor, ...]:
input_grad = reduce(output_grad, ctx.src_rank, ctx.input_parallel_mode)
if gpc.get_local_rank(ctx.input_parallel_mode) == gpc.get_local_rank(ctx.output_parallel_mode):
input_grad = all_reduce(input_grad, ctx.weight_parallel_mode)
else:
input_grad = None
return input_grad, None, None, None
def broadcast_weight_3d_from_diagonal(tensor: Tensor, input_parallel_mode: ParallelMode,
weight_parallel_mode: ParallelMode, output_parallel_mode: ParallelMode) -> Tensor:
return _BroadcastWeight3D_FromDiagonal.apply(tensor, input_parallel_mode, weight_parallel_mode,
output_parallel_mode)
from colossalai.constants import INPUT_GROUP_3D, WEIGHT_GROUP_3D, OUTPUT_GROUP_3D from collections import OrderedDict
from colossalai.context.parallel_mode import ParallelMode from functools import partial
import torch
from torch import Tensor
from colossalai.constants import INPUT_GROUP_3D, INPUT_X_WEIGHT_3D, OUTPUT_GROUP_3D, OUTPUT_X_WEIGHT_3D, WEIGHT_GROUP_3D
from colossalai.core import global_context as gpc from colossalai.core import global_context as gpc
from colossalai.global_variables import tensor_parallel_env as env from colossalai.global_variables import tensor_parallel_env as env
from torch import Tensor
def get_depth_from_env() -> int: def get_depth_from_env() -> int:
...@@ -17,30 +21,17 @@ def get_depth_from_env() -> int: ...@@ -17,30 +21,17 @@ def get_depth_from_env() -> int:
def get_parallel_mode_from_env(group): def get_parallel_mode_from_env(group):
assert group in [INPUT_GROUP_3D, WEIGHT_GROUP_3D, OUTPUT_GROUP_3D], \ assert group in [INPUT_GROUP_3D, WEIGHT_GROUP_3D, OUTPUT_GROUP_3D, INPUT_X_WEIGHT_3D, OUTPUT_X_WEIGHT_3D], \
f'{group} is not valid for 3D tensor parallelism.' f'{group} is not valid for 3D tensor parallelism.'
return getattr(env, group) return getattr(env, group)
def get_last_group(a, b):
mapping = {
ParallelMode.PARALLEL_3D_INPUT: 'A',
ParallelMode.PARALLEL_3D_WEIGHT: 'B',
ParallelMode.PARALLEL_3D_OUTPUT: 'C',
}
res = chr(ord('A') + ord('B') + ord('C') - ord(mapping[a]) - ord(mapping[b]))
if res == 'A':
return ParallelMode.PARALLEL_3D_INPUT
elif res == 'B':
return ParallelMode.PARALLEL_3D_WEIGHT
elif res == 'C':
return ParallelMode.PARALLEL_3D_OUTPUT
def swap_in_out_group(): def swap_in_out_group():
env.input_group_3d, env.output_group_3d = env.output_group_3d, env.input_group_3d env.input_group_3d, env.output_group_3d = env.output_group_3d, env.input_group_3d
env.input_x_weight_group_3d, env.output_x_weight_group_3d = (
env.output_x_weight_group_3d,
env.input_x_weight_group_3d,
)
def dbg_check_shape(tensor: Tensor, shape: tuple): def dbg_check_shape(tensor: Tensor, shape: tuple):
...@@ -49,3 +40,60 @@ def dbg_check_shape(tensor: Tensor, shape: tuple): ...@@ -49,3 +40,60 @@ def dbg_check_shape(tensor: Tensor, shape: tuple):
print(tensor.shape) print(tensor.shape)
assert tensor.shape == shape, \ assert tensor.shape == shape, \
'{} does not match {}'.format(tensor.shape, shape) '{} does not match {}'.format(tensor.shape, shape)
class AsyncGradientBucket(object):
def __init__(self):
self.bucket = OrderedDict()
def __len__(self):
return len(self.bucket)
def push(self, async_op, grad_tensor, param_id):
self.bucket[param_id] = tuple((async_op, grad_tensor))
return torch.zeros_like(grad_tensor, dtype=grad_tensor.dtype, device=grad_tensor.device)
def pop(self, param_id):
grad = None
if param_id in self.bucket:
op, grad = self.bucket.pop(param_id)
if op is not None:
op.wait()
return grad
def synchronize(self, params):
for p in params:
i = id(p)
if i in self.bucket:
op, grad = self.bucket.pop(i)
if op is not None:
op.wait()
p.grad.add_(grad)
_async_grad_bucket = AsyncGradientBucket()
def push_async_grad(op, grad, param_id):
return _async_grad_bucket.push(op, grad, param_id)
def pop_async_grad(param_id):
return _async_grad_bucket.pop(param_id)
def _async_grad_hook(grad, param_id):
grad.add_(pop_async_grad(param_id))
return grad
def register_async_grad_hook(param):
param.register_hook(partial(_async_grad_hook, param_id=id(param)))
def synchronize(params=list()):
_async_grad_bucket.synchronize(params)
torch.cuda.default_stream().synchronize()
if len(_async_grad_bucket) > 0:
raise RuntimeError(f"{len(_async_grad_bucket)} asynchronous gradient(s) not collected.")
...@@ -5,24 +5,36 @@ from typing import Callable ...@@ -5,24 +5,36 @@ from typing import Callable
import torch import torch
import torch.nn as nn import torch.nn as nn
import torch.nn.functional as F import torch.nn.functional as F
from torch import Tensor
from torch.nn import Parameter
from colossalai.communication import all_reduce, broadcast from colossalai.communication import all_reduce, broadcast
from colossalai.constants import INPUT_GROUP_3D, WEIGHT_GROUP_3D from colossalai.constants import INPUT_GROUP_3D, INPUT_X_WEIGHT_3D, OUTPUT_GROUP_3D, OUTPUT_X_WEIGHT_3D, WEIGHT_GROUP_3D
from colossalai.context import ParallelMode, seed from colossalai.context import ParallelMode, seed
from colossalai.core import global_context as gpc from colossalai.core import global_context as gpc
from colossalai.global_variables import tensor_parallel_env as env from colossalai.global_variables import tensor_parallel_env as env
from colossalai.nn import init as init from colossalai.nn import init as init
from colossalai.nn.layer.base_layer import ParallelLayer from colossalai.nn.layer.base_layer import ParallelLayer
from colossalai.registry import LAYERS from colossalai.registry import LAYERS
from colossalai.utils.checkpointing import (broadcast_state_dict, gather_tensor_parallel_state_dict, from colossalai.utils.checkpointing import (
partition_tensor_parallel_state_dict) broadcast_state_dict,
gather_tensor_parallel_state_dict,
partition_tensor_parallel_state_dict,
)
from colossalai.utils.cuda import get_current_device from colossalai.utils.cuda import get_current_device
from torch import Tensor
from torch.nn import Parameter
from ..utils import divide, set_tensor_parallel_attribute_by_partition, to_2tuple from ..utils import divide, set_tensor_parallel_attribute_by_partition, to_2tuple
from ._operation import (all_gather_tensor_3d, broadcast_weight_3d_from_diagonal, classifier_3d, layernorm_3d, from ._operation import (
linear_3d, reduce_scatter_tensor_3d, split_tensor_3d) all_gather_tensor_3d,
from ._utils import get_depth_from_env, get_last_group, get_parallel_mode_from_env, swap_in_out_group classifier_3d,
layernorm_3d,
linear_3d,
reduce_scatter_tensor_3d,
split_batch_3d,
split_tensor_3d,
vocab_parallel_classifier_3d,
)
from ._utils import get_depth_from_env, get_parallel_mode_from_env, register_async_grad_hook, swap_in_out_group
@LAYERS.register_module @LAYERS.register_module
...@@ -45,7 +57,8 @@ class LayerNorm3D(ParallelLayer): ...@@ -45,7 +57,8 @@ class LayerNorm3D(ParallelLayer):
super().__init__() super().__init__()
self.input_parallel_mode = get_parallel_mode_from_env(INPUT_GROUP_3D) self.input_parallel_mode = get_parallel_mode_from_env(INPUT_GROUP_3D)
self.weight_parallel_mode = get_parallel_mode_from_env(WEIGHT_GROUP_3D) self.weight_parallel_mode = get_parallel_mode_from_env(WEIGHT_GROUP_3D)
self.output_parallel_mode = get_last_group(self.input_parallel_mode, self.weight_parallel_mode) self.output_parallel_mode = get_parallel_mode_from_env(OUTPUT_GROUP_3D)
self.input_x_weight_parallel_mode = get_parallel_mode_from_env(INPUT_X_WEIGHT_3D)
self.depth = get_depth_from_env() self.depth = get_depth_from_env()
self.normalized_shape = normalized_shape self.normalized_shape = normalized_shape
self.normalized_shape_per_partition = divide(normalized_shape, self.depth) self.normalized_shape_per_partition = divide(normalized_shape, self.depth)
...@@ -58,6 +71,7 @@ class LayerNorm3D(ParallelLayer): ...@@ -58,6 +71,7 @@ class LayerNorm3D(ParallelLayer):
else: else:
self.bias = None self.bias = None
self.variance_epsilon = eps self.variance_epsilon = eps
self.reset_parameters()
self._set_tensor_parallel_attributes() self._set_tensor_parallel_attributes()
def _set_tensor_parallel_attributes(self) -> None: def _set_tensor_parallel_attributes(self) -> None:
...@@ -67,8 +81,10 @@ class LayerNorm3D(ParallelLayer): ...@@ -67,8 +81,10 @@ class LayerNorm3D(ParallelLayer):
def reset_parameters(self) -> None: def reset_parameters(self) -> None:
init.ones_()(self.weight) init.ones_()(self.weight)
register_async_grad_hook(self.weight)
if self.bias is not None: if self.bias is not None:
init.zeros_()(self.bias) init.zeros_()(self.bias)
register_async_grad_hook(self.bias)
def _load_from_global_state_dict(self, state_dict, prefix, *args, **kwargs): def _load_from_global_state_dict(self, state_dict, prefix, *args, **kwargs):
local_state = OrderedDict() local_state = OrderedDict()
...@@ -134,8 +150,15 @@ class LayerNorm3D(ParallelLayer): ...@@ -134,8 +150,15 @@ class LayerNorm3D(ParallelLayer):
destination.update(local_state) destination.update(local_state)
def forward(self, input_: Tensor) -> Tensor: def forward(self, input_: Tensor) -> Tensor:
return layernorm_3d(input_, self.weight, self.bias, self.normalized_shape, self.variance_epsilon, return layernorm_3d(
self.input_parallel_mode, self.weight_parallel_mode, self.output_parallel_mode) input_,
self.weight,
self.bias,
self.normalized_shape,
self.variance_epsilon,
self.output_parallel_mode,
self.input_x_weight_parallel_mode,
)
@LAYERS.register_module @LAYERS.register_module
...@@ -161,6 +184,7 @@ class Linear3D(ParallelLayer): ...@@ -161,6 +184,7 @@ class Linear3D(ParallelLayer):
out_features: int, out_features: int,
bias: bool = True, bias: bool = True,
dtype: torch.dtype = None, dtype: torch.dtype = None,
skip_bias_add: bool = False,
weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)), weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)),
bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1)): bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1)):
super().__init__() super().__init__()
...@@ -168,10 +192,12 @@ class Linear3D(ParallelLayer): ...@@ -168,10 +192,12 @@ class Linear3D(ParallelLayer):
self.out_features = out_features self.out_features = out_features
self.input_parallel_mode = get_parallel_mode_from_env(INPUT_GROUP_3D) self.input_parallel_mode = get_parallel_mode_from_env(INPUT_GROUP_3D)
self.weight_parallel_mode = get_parallel_mode_from_env(WEIGHT_GROUP_3D) self.weight_parallel_mode = get_parallel_mode_from_env(WEIGHT_GROUP_3D)
self.output_parallel_mode = get_last_group(self.input_parallel_mode, self.weight_parallel_mode) self.output_parallel_mode = get_parallel_mode_from_env(OUTPUT_GROUP_3D)
self.output_x_weight_parallel_mode = get_parallel_mode_from_env(OUTPUT_X_WEIGHT_3D)
self.depth = get_depth_from_env() self.depth = get_depth_from_env()
self.in_features_per_partition = divide(in_features, self.depth) self.skip_bias_add = skip_bias_add
self.out_features_per_partition = divide(out_features, self.depth**2) self.in_features_per_partition = divide(in_features, self.depth**2)
self.out_features_per_partition = divide(out_features, self.depth)
self.bias_features_per_partition = divide(out_features, self.depth) self.bias_features_per_partition = divide(out_features, self.depth)
self.weight = Parameter( self.weight = Parameter(
...@@ -194,18 +220,23 @@ class Linear3D(ParallelLayer): ...@@ -194,18 +220,23 @@ class Linear3D(ParallelLayer):
if self.bias is not None: if self.bias is not None:
set_tensor_parallel_attribute_by_partition(self.bias, self.depth) set_tensor_parallel_attribute_by_partition(self.bias, self.depth)
def _sync_grad_hook(self, grad) -> Tensor:
grad = all_reduce(grad.clone(), self.output_x_weight_parallel_mode)
return grad
def reset_parameters(self, weight_initializer, bias_initializer) -> None: def reset_parameters(self, weight_initializer, bias_initializer) -> None:
with seed(ParallelMode.TENSOR): with seed(ParallelMode.TENSOR):
fan_in, fan_out = self.in_features, self.out_features fan_in, fan_out = self.in_features, self.out_features
weight_initializer(self.weight, fan_in=fan_in, fan_out=fan_out) weight_initializer(self.weight, fan_in=fan_in, fan_out=fan_out)
register_async_grad_hook(self.weight)
if self.bias is not None: if self.bias is not None:
bias_initializer(self.bias, fan_in=fan_in) bias_initializer(self.bias, fan_in=fan_in)
weight_src_rank = gpc.get_ranks_in_group(self.weight_parallel_mode)[0] broadcast(self.bias,
output_src_rank = gpc.get_ranks_in_group(self.output_parallel_mode)[0] gpc.get_ranks_in_group(self.output_x_weight_parallel_mode)[0],
broadcast(self.bias, weight_src_rank, self.weight_parallel_mode) self.output_x_weight_parallel_mode)
broadcast(self.bias, output_src_rank, self.output_parallel_mode) self.bias.register_hook(self._sync_grad_hook)
def _load_from_global_state_dict(self, state_dict, prefix, *args, **kwargs): def _load_from_global_state_dict(self, state_dict, prefix, *args, **kwargs):
local_state = OrderedDict() local_state = OrderedDict()
...@@ -256,7 +287,7 @@ class Linear3D(ParallelLayer): ...@@ -256,7 +287,7 @@ class Linear3D(ParallelLayer):
local_state, local_state,
self.weight_parallel_mode, self.weight_parallel_mode,
dims={ dims={
weight_key: -1, weight_key: 0,
bias_key: 0 bias_key: 0
}, },
partition_states={ partition_states={
...@@ -279,7 +310,7 @@ class Linear3D(ParallelLayer): ...@@ -279,7 +310,7 @@ class Linear3D(ParallelLayer):
local_state, local_state,
self.weight_parallel_mode, self.weight_parallel_mode,
dims={ dims={
weight_key: -1, weight_key: 0,
bias_key: 0 bias_key: 0
}, },
partition_states={ partition_states={
...@@ -324,8 +355,20 @@ class Linear3D(ParallelLayer): ...@@ -324,8 +355,20 @@ class Linear3D(ParallelLayer):
destination.update(local_state) destination.update(local_state)
def forward(self, input_: Tensor) -> Tensor: def forward(self, input_: Tensor) -> Tensor:
return linear_3d(input_, self.weight, self.bias, self.input_parallel_mode, self.weight_parallel_mode, output = linear_3d(
self.output_parallel_mode) input_,
self.weight,
self.input_parallel_mode,
self.weight_parallel_mode,
self.output_parallel_mode,
)
if not self.skip_bias_add:
if self.bias is not None:
output = output + self.bias
return output
else:
return output, self.bias
@LAYERS.register_module @LAYERS.register_module
...@@ -360,7 +403,7 @@ class Classifier3D(ParallelLayer): ...@@ -360,7 +403,7 @@ class Classifier3D(ParallelLayer):
self.num_classes = num_classes self.num_classes = num_classes
self.input_parallel_mode = get_parallel_mode_from_env(INPUT_GROUP_3D) self.input_parallel_mode = get_parallel_mode_from_env(INPUT_GROUP_3D)
self.weight_parallel_mode = get_parallel_mode_from_env(WEIGHT_GROUP_3D) self.weight_parallel_mode = get_parallel_mode_from_env(WEIGHT_GROUP_3D)
self.output_parallel_mode = get_last_group(self.input_parallel_mode, self.weight_parallel_mode) self.output_parallel_mode = get_parallel_mode_from_env(OUTPUT_GROUP_3D)
self.depth = get_depth_from_env() self.depth = get_depth_from_env()
self.in_features_per_partition = divide(in_features, self.depth) self.in_features_per_partition = divide(in_features, self.depth)
...@@ -386,19 +429,17 @@ class Classifier3D(ParallelLayer): ...@@ -386,19 +429,17 @@ class Classifier3D(ParallelLayer):
def reset_parameters(self, weight_initializer, bias_initializer) -> None: def reset_parameters(self, weight_initializer, bias_initializer) -> None:
with seed(ParallelMode.TENSOR): with seed(ParallelMode.TENSOR):
fan_in, fan_out = self.in_features, self.num_classes fan_in, fan_out = self.in_features, self.num_classes
weight_src_rank = gpc.get_ranks_in_group(self.weight_parallel_mode)[0]
output_src_rank = gpc.get_ranks_in_group(self.output_parallel_mode)[0]
input_src_rank = gpc.get_ranks_in_group(self.input_parallel_mode)[0]
if self.has_weight: if self.has_weight:
weight_initializer(self.weight, fan_in=fan_in, fan_out=fan_out) weight_initializer(self.weight, fan_in=fan_in, fan_out=fan_out)
broadcast(self.weight, weight_src_rank, self.weight_parallel_mode) broadcast(self.weight, gpc.get_ranks_in_group(self.weight_parallel_mode)[0], self.weight_parallel_mode)
register_async_grad_hook(self.weight)
if self.bias is not None: if self.bias is not None:
bias_initializer(self.bias, fan_in=fan_in) bias_initializer(self.bias, fan_in=fan_in)
broadcast(self.bias, weight_src_rank, self.weight_parallel_mode) broadcast(self.bias, gpc.get_ranks_in_group(ParallelMode.TENSOR)[0], ParallelMode.TENSOR)
broadcast(self.bias, output_src_rank, self.output_parallel_mode) register_async_grad_hook(self.bias)
broadcast(self.bias, input_src_rank, self.input_parallel_mode)
def _load_from_global_state_dict(self, state_dict, prefix, *args, **kwargs): def _load_from_global_state_dict(self, state_dict, prefix, *args, **kwargs):
local_state = OrderedDict() local_state = OrderedDict()
...@@ -468,8 +509,14 @@ class Classifier3D(ParallelLayer): ...@@ -468,8 +509,14 @@ class Classifier3D(ParallelLayer):
destination.update(local_state) destination.update(local_state)
def forward(self, input_: Tensor) -> Tensor: def forward(self, input_: Tensor) -> Tensor:
return classifier_3d(input_, self.weight, self.bias, self.input_parallel_mode, self.weight_parallel_mode, return classifier_3d(
self.output_parallel_mode) input_,
self.weight,
self.bias,
self.input_parallel_mode,
self.weight_parallel_mode,
self.output_parallel_mode,
)
@LAYERS.register_module @LAYERS.register_module
...@@ -504,7 +551,8 @@ class VocabParallelClassifier3D(ParallelLayer): ...@@ -504,7 +551,8 @@ class VocabParallelClassifier3D(ParallelLayer):
self.num_classes = num_classes self.num_classes = num_classes
self.input_parallel_mode = get_parallel_mode_from_env(INPUT_GROUP_3D) self.input_parallel_mode = get_parallel_mode_from_env(INPUT_GROUP_3D)
self.weight_parallel_mode = get_parallel_mode_from_env(WEIGHT_GROUP_3D) self.weight_parallel_mode = get_parallel_mode_from_env(WEIGHT_GROUP_3D)
self.output_parallel_mode = get_last_group(self.input_parallel_mode, self.weight_parallel_mode) self.output_parallel_mode = get_parallel_mode_from_env(OUTPUT_GROUP_3D)
self.output_x_weight_parallel_mode = get_parallel_mode_from_env(OUTPUT_X_WEIGHT_3D)
self.depth = get_depth_from_env() self.depth = get_depth_from_env()
self.in_features_per_partition = divide(in_features, self.depth) self.in_features_per_partition = divide(in_features, self.depth)
self.out_features_per_partition = divide(num_classes, self.depth**2) self.out_features_per_partition = divide(num_classes, self.depth**2)
...@@ -544,12 +592,14 @@ class VocabParallelClassifier3D(ParallelLayer): ...@@ -544,12 +592,14 @@ class VocabParallelClassifier3D(ParallelLayer):
if self.has_weight: if self.has_weight:
weight_initializer(self.weight, fan_in=fan_in, fan_out=fan_out) weight_initializer(self.weight, fan_in=fan_in, fan_out=fan_out)
register_async_grad_hook(self.weight)
if self.bias is not None: if self.bias is not None:
bias_initializer(self.bias, fan_in=fan_in) bias_initializer(self.bias, fan_in=fan_in)
weight_src_rank = gpc.get_ranks_in_group(self.weight_parallel_mode)[0] broadcast(self.bias,
output_src_rank = gpc.get_ranks_in_group(self.output_parallel_mode)[0] gpc.get_ranks_in_group(self.output_x_weight_parallel_mode)[0],
broadcast(self.bias, weight_src_rank, self.weight_parallel_mode) self.output_x_weight_parallel_mode)
broadcast(self.bias, output_src_rank, self.output_parallel_mode) register_async_grad_hook(self.bias)
def _load_from_global_state_dict(self, state_dict, prefix, *args, **kwargs): def _load_from_global_state_dict(self, state_dict, prefix, *args, **kwargs):
local_state = OrderedDict() local_state = OrderedDict()
...@@ -668,8 +718,14 @@ class VocabParallelClassifier3D(ParallelLayer): ...@@ -668,8 +718,14 @@ class VocabParallelClassifier3D(ParallelLayer):
destination.update(local_state) destination.update(local_state)
def forward(self, input_: Tensor) -> Tensor: def forward(self, input_: Tensor) -> Tensor:
return linear_3d(input_, self.weight.transpose(0, 1), self.bias, self.input_parallel_mode, return vocab_parallel_classifier_3d(
self.weight_parallel_mode, self.output_parallel_mode) input_,
self.weight,
self.bias,
self.input_parallel_mode,
self.weight_parallel_mode,
self.output_parallel_mode,
)
@LAYERS.register_module @LAYERS.register_module
...@@ -708,12 +764,16 @@ class PatchEmbedding3D(ParallelLayer): ...@@ -708,12 +764,16 @@ class PatchEmbedding3D(ParallelLayer):
self.depth = get_depth_from_env() self.depth = get_depth_from_env()
self.input_parallel_mode = get_parallel_mode_from_env(INPUT_GROUP_3D) self.input_parallel_mode = get_parallel_mode_from_env(INPUT_GROUP_3D)
self.weight_parallel_mode = get_parallel_mode_from_env(WEIGHT_GROUP_3D) self.weight_parallel_mode = get_parallel_mode_from_env(WEIGHT_GROUP_3D)
self.output_parallel_mode = get_last_group(self.input_parallel_mode, self.weight_parallel_mode) self.output_parallel_mode = get_parallel_mode_from_env(OUTPUT_GROUP_3D)
self.patch_size = to_2tuple(patch_size) self.input_x_weight_parallel_mode = get_parallel_mode_from_env(INPUT_X_WEIGHT_3D)
grid_size = to_2tuple(img_size // patch_size) img_size = to_2tuple(img_size)
num_patches = grid_size[0] * grid_size[1] patch_size = to_2tuple(patch_size)
self.img_size = img_size
self.patch_size = patch_size
self.grid_size = (img_size[0] // patch_size[0], img_size[1] // patch_size[1])
self.num_patches = self.grid_size[0] * self.grid_size[1]
self.embed_size = embed_size self.embed_size = embed_size
embed_size_per_partition = divide(embed_size, self.depth) embed_size_per_partition = embed_size // self.depth
self.flatten = flatten self.flatten = flatten
self.weight = nn.Parameter( self.weight = nn.Parameter(
...@@ -725,7 +785,7 @@ class PatchEmbedding3D(ParallelLayer): ...@@ -725,7 +785,7 @@ class PatchEmbedding3D(ParallelLayer):
self.cls_token = nn.Parameter( self.cls_token = nn.Parameter(
torch.zeros((1, 1, embed_size_per_partition), device=get_current_device(), dtype=dtype)) torch.zeros((1, 1, embed_size_per_partition), device=get_current_device(), dtype=dtype))
self.pos_embed = nn.Parameter( self.pos_embed = nn.Parameter(
torch.zeros((1, num_patches + 1, embed_size_per_partition), device=get_current_device(), dtype=dtype)) torch.zeros((1, self.num_patches + 1, embed_size_per_partition), device=get_current_device(), dtype=dtype))
self.reset_parameters(weight_initializer, bias_initializer, position_embed_initializer) self.reset_parameters(weight_initializer, bias_initializer, position_embed_initializer)
self._set_tensor_parallel_attributes() self._set_tensor_parallel_attributes()
...@@ -737,8 +797,7 @@ class PatchEmbedding3D(ParallelLayer): ...@@ -737,8 +797,7 @@ class PatchEmbedding3D(ParallelLayer):
set_tensor_parallel_attribute_by_partition(self.pos_embed, self.depth) set_tensor_parallel_attribute_by_partition(self.pos_embed, self.depth)
def _sync_grad_hook(self, grad) -> Tensor: def _sync_grad_hook(self, grad) -> Tensor:
grad = all_reduce(grad.clone(), self.input_parallel_mode) grad = all_reduce(grad.clone(), self.input_x_weight_parallel_mode)
grad = all_reduce(grad, self.weight_parallel_mode)
return grad return grad
def reset_parameters(self, weight_initializer, bias_initializer, position_embed_initializer) -> None: def reset_parameters(self, weight_initializer, bias_initializer, position_embed_initializer) -> None:
...@@ -749,14 +808,10 @@ class PatchEmbedding3D(ParallelLayer): ...@@ -749,14 +808,10 @@ class PatchEmbedding3D(ParallelLayer):
bias_initializer(self.bias, fan_in=fan_in) bias_initializer(self.bias, fan_in=fan_in)
position_embed_initializer(self.pos_embed) position_embed_initializer(self.pos_embed)
weight_src_rank = gpc.get_ranks_in_group(self.weight_parallel_mode)[0] src_rank = gpc.get_ranks_in_group(self.input_x_weight_parallel_mode)[0]
input_src_rank = gpc.get_ranks_in_group(self.input_parallel_mode)[0] broadcast(self.weight, src_rank, self.input_x_weight_parallel_mode)
broadcast(self.weight, weight_src_rank, self.weight_parallel_mode) broadcast(self.bias, src_rank, self.input_x_weight_parallel_mode)
broadcast(self.bias, weight_src_rank, self.weight_parallel_mode) broadcast(self.pos_embed, src_rank, self.input_x_weight_parallel_mode)
broadcast(self.pos_embed, weight_src_rank, self.weight_parallel_mode)
broadcast(self.weight, input_src_rank, self.input_parallel_mode)
broadcast(self.bias, input_src_rank, self.input_parallel_mode)
broadcast(self.pos_embed, input_src_rank, self.input_parallel_mode)
self.weight.register_hook(self._sync_grad_hook) self.weight.register_hook(self._sync_grad_hook)
self.bias.register_hook(self._sync_grad_hook) self.bias.register_hook(self._sync_grad_hook)
...@@ -850,8 +905,9 @@ class PatchEmbedding3D(ParallelLayer): ...@@ -850,8 +905,9 @@ class PatchEmbedding3D(ParallelLayer):
destination.update(local_state) destination.update(local_state)
def forward(self, input_: Tensor) -> Tensor: def forward(self, input_: Tensor) -> Tensor:
input_ = split_tensor_3d(input_, 0, self.weight_parallel_mode) input_ = split_batch_3d(input_,
input_ = split_tensor_3d(input_, 0, self.input_parallel_mode) input_parallel_mode=self.input_parallel_mode,
weight_parallel_mode=self.weight_parallel_mode)
output = F.conv2d(input_, self.weight, self.bias, stride=self.patch_size) output = F.conv2d(input_, self.weight, self.bias, stride=self.patch_size)
if self.flatten: if self.flatten:
output = output.flatten(2).transpose(1, 2) # BCHW -> BNC output = output.flatten(2).transpose(1, 2) # BCHW -> BNC
...@@ -906,7 +962,8 @@ class Embedding3D(ParallelLayer): ...@@ -906,7 +962,8 @@ class Embedding3D(ParallelLayer):
self.depth = get_depth_from_env() self.depth = get_depth_from_env()
self.input_parallel_mode = get_parallel_mode_from_env(INPUT_GROUP_3D) self.input_parallel_mode = get_parallel_mode_from_env(INPUT_GROUP_3D)
self.weight_parallel_mode = get_parallel_mode_from_env(WEIGHT_GROUP_3D) self.weight_parallel_mode = get_parallel_mode_from_env(WEIGHT_GROUP_3D)
self.output_parallel_mode = get_last_group(self.input_parallel_mode, self.weight_parallel_mode) self.output_parallel_mode = get_parallel_mode_from_env(OUTPUT_GROUP_3D)
self.input_x_weight_parallel_mode = get_parallel_mode_from_env(INPUT_X_WEIGHT_3D)
self.num_embeddings = num_embeddings self.num_embeddings = num_embeddings
self.embed_dim = embedding_dim self.embed_dim = embedding_dim
...@@ -924,13 +981,18 @@ class Embedding3D(ParallelLayer): ...@@ -924,13 +981,18 @@ class Embedding3D(ParallelLayer):
def _set_tensor_parallel_attributes(self) -> None: def _set_tensor_parallel_attributes(self) -> None:
set_tensor_parallel_attribute_by_partition(self.weight, self.depth) set_tensor_parallel_attribute_by_partition(self.weight, self.depth)
def _sync_grad_hook(self, grad) -> Tensor:
grad = all_reduce(grad.clone(), self.input_x_weight_parallel_mode)
return grad
def reset_parameters(self, weight_initializer) -> None: def reset_parameters(self, weight_initializer) -> None:
with seed(ParallelMode.TENSOR): with seed(ParallelMode.TENSOR):
fan_in, fan_out = self.num_embeddings, self.embed_dim fan_in, fan_out = self.num_embeddings, self.embed_dim
weight_initializer(self.weight, fan_in=fan_in, fan_out=fan_out) weight_initializer(self.weight, fan_in=fan_in, fan_out=fan_out)
self._fill_padding_idx_with_zero() self._fill_padding_idx_with_zero()
weight_src_rank = gpc.get_ranks_in_group(self.weight_parallel_mode)[0] broadcast(self.weight,
broadcast(self.weight, weight_src_rank, self.weight_parallel_mode) gpc.get_ranks_in_group(self.input_x_weight_parallel_mode)[0], self.input_x_weight_parallel_mode)
self.weight.register_hook(self._sync_grad_hook)
def _fill_padding_idx_with_zero(self) -> None: def _fill_padding_idx_with_zero(self) -> None:
if self.padding_idx is not None: if self.padding_idx is not None:
...@@ -981,11 +1043,10 @@ class Embedding3D(ParallelLayer): ...@@ -981,11 +1043,10 @@ class Embedding3D(ParallelLayer):
destination.update(local_state) destination.update(local_state)
def forward(self, input_: Tensor) -> Tensor: def forward(self, input_: Tensor) -> Tensor:
input_ = split_tensor_3d(input_, 0, self.weight_parallel_mode) input_ = split_batch_3d(input_,
input_ = split_tensor_3d(input_, 0, self.input_parallel_mode) input_parallel_mode=self.input_parallel_mode,
weight = broadcast_weight_3d_from_diagonal(self.weight, self.input_parallel_mode, self.weight_parallel_mode, weight_parallel_mode=self.weight_parallel_mode)
self.output_parallel_mode) output = F.embedding(input_, self.weight, self.padding_idx, *self.embed_args, **self.embed_kwargs)
output = F.embedding(input_, weight, self.padding_idx, *self.embed_args, **self.embed_kwargs)
return output return output
...@@ -1039,7 +1100,7 @@ class VocabParallelEmbedding3D(ParallelLayer): ...@@ -1039,7 +1100,7 @@ class VocabParallelEmbedding3D(ParallelLayer):
self.depth = get_depth_from_env() self.depth = get_depth_from_env()
self.input_parallel_mode = get_parallel_mode_from_env(INPUT_GROUP_3D) self.input_parallel_mode = get_parallel_mode_from_env(INPUT_GROUP_3D)
self.weight_parallel_mode = get_parallel_mode_from_env(WEIGHT_GROUP_3D) self.weight_parallel_mode = get_parallel_mode_from_env(WEIGHT_GROUP_3D)
self.output_parallel_mode = get_last_group(self.input_parallel_mode, self.weight_parallel_mode) self.output_parallel_mode = get_parallel_mode_from_env(OUTPUT_GROUP_3D)
self.num_embeddings_per_partition = divide(self.num_embeddings, self.depth**2) self.num_embeddings_per_partition = divide(self.num_embeddings, self.depth**2)
self.embed_dim_per_partition = divide(self.embed_dim, self.depth) self.embed_dim_per_partition = divide(self.embed_dim, self.depth)
vocab_parallel_rank = gpc.get_local_rank(self.input_parallel_mode) vocab_parallel_rank = gpc.get_local_rank(self.input_parallel_mode)
......
from .layers import (DropPath, VanillaClassifier, VanillaLayerNorm, VanillaPatchEmbedding, WrappedDropout, from .layers import (
WrappedDropPath) DropPath,
VanillaClassifier,
VanillaLayerNorm,
VanillaLinear,
VanillaPatchEmbedding,
WrappedDropout,
WrappedDropPath,
)
__all__ = [ __all__ = [
"VanillaLayerNorm", "VanillaPatchEmbedding", "VanillaClassifier", "DropPath", "WrappedDropout", "WrappedDropPath" "VanillaLayerNorm", "VanillaPatchEmbedding", "VanillaClassifier", "DropPath", "WrappedDropout", "WrappedDropPath",
"VanillaLinear"
] ]
import math import math
from typing import Callable from typing import Callable
import torch import torch
import torch.nn.functional as F import torch.nn.functional as F
from colossalai.context import seed from torch import Tensor
from colossalai.nn import init as init from torch import nn as nn
from colossalai.registry import LAYERS from torch.nn.parameter import Parameter
from colossalai.utils.cuda import get_current_device
from torch import Tensor from colossalai.context import seed
from torch import nn as nn from colossalai.nn import init as init
from colossalai.registry import LAYERS
from ..utils import to_2tuple from colossalai.utils.cuda import get_current_device
from ..utils import to_2tuple
def drop_path(x, drop_prob: float = 0., training: bool = False):
"""Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
def drop_path(x, drop_prob: float = 0., training: bool = False):
This is the same as the DropConnect impl I created for EfficientNet, etc networks, however, """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper...
See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for This is the same as the DropConnect impl I created for EfficientNet, etc networks, however,
changing the layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper...
'survival rate' as the argument. See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for
changing the layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use
Args: 'survival rate' as the argument.
drop_prob (float, optional): probability of dropping path, defaults 0.0.
training (bool, optional): whether in training progress, defaults False. Args:
""" drop_prob (float, optional): probability of dropping path, defaults 0.0.
if drop_prob == 0. or not training: training (bool, optional): whether in training progress, defaults False.
return x """
keep_prob = 1 - drop_prob if drop_prob == 0. or not training:
shape = (x.shape[0],) + (1,) * (x.ndim - 1) # work with diff dim tensors, not just 2D ConvNets return x
random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device) keep_prob = 1 - drop_prob
random_tensor.floor_() # binarize shape = (x.shape[0],) + (1,) * (x.ndim - 1) # work with diff dim tensors, not just 2D ConvNets
output = x.div(keep_prob) * random_tensor random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device)
return output random_tensor.floor_() # binarize
output = x.div(keep_prob) * random_tensor
return output
class DropPath(nn.Module):
"""
Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks). class DropPath(nn.Module):
Adapted from https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/layers/drop.py """
Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
Args: Adapted from https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/layers/drop.py
drop_prob (float, optional): probability of dropping path, defaults None.
""" Args:
drop_prob (float, optional): probability of dropping path, defaults None.
def __init__(self, drop_prob=None): """
super(DropPath, self).__init__()
self.drop_prob = drop_prob def __init__(self, drop_prob=None):
super(DropPath, self).__init__()
def forward(self, x): self.drop_prob = drop_prob
return drop_path(x, self.drop_prob, self.training)
def forward(self, x):
return drop_path(x, self.drop_prob, self.training)
class WrappedDropout(nn.Module):
r"""Same as torch.nn.Dropout. But it is wrapped with the context of seed manager. During training, randomly zeroes
some elements of the input tensor with probability p using samples from a Bernoulli distribution. Each class WrappedDropout(nn.Module):
channel will be zeroed out independently on every forward call. Furthermore, the outputs are scaled by a factor of r"""Same as torch.nn.Dropout. But it is wrapped with the context of seed manager. During training, randomly zeroes
1/(1-p) during training. This means that during evaluation the module simply computes an identity function. some elements of the input tensor with probability p using samples from a Bernoulli distribution. Each
channel will be zeroed out independently on every forward call. Furthermore, the outputs are scaled by a factor of
Args: 1/(1-p) during training. This means that during evaluation the module simply computes an identity function.
p (float, optional): probability of an element to be zeroed, defaults 0.5.
inplace (bool, optional): whether to do dropout in-place, default to be False. Args:
mode (:class:`colossalai.context.ParallelMode`): The chosen parallel mode. p (float, optional): probability of an element to be zeroed, defaults 0.5.
inplace (bool, optional): whether to do dropout in-place, default to be False.
Note: mode (:class:`colossalai.context.ParallelMode`): The chosen parallel mode.
The parallel_mode should be concluded in ``ParallelMode``. More details about ``ParallelMode`` could be found
in `parallel_mode <https://github.com/hpcaitech/ColossalAI/blob/main/colossalai/context/parallel_mode.py>`_ Note:
""" The parallel_mode should be concluded in ``ParallelMode``. More details about ``ParallelMode`` could be found
in `parallel_mode <https://github.com/hpcaitech/ColossalAI/blob/main/colossalai/context/parallel_mode.py>`_
def __init__(self, p: float = 0.5, inplace: bool = False, mode=None): """
super().__init__()
if p < 0 or p > 1: def __init__(self, p: float = 0.5, inplace: bool = False, mode=None):
raise ValueError("dropout probability has to be between 0 and 1, " super().__init__()
"but got {}".format(p)) if p < 0 or p > 1:
self.p = p raise ValueError("dropout probability has to be between 0 and 1, "
self.inplace = inplace "but got {}".format(p))
if mode is None: self.p = p
self.func = self.nonefunc self.inplace = inplace
else: if mode is None:
self.func = self.normalfunc self.func = self.nonefunc
self.mode = mode else:
self.func = self.normalfunc
def nonefunc(self, inputs): self.mode = mode
return F.dropout(inputs, self.p, self.training, self.inplace)
def nonefunc(self, inputs):
def normalfunc(self, inputs): return F.dropout(inputs, self.p, self.training, self.inplace)
with seed(self.mode):
return F.dropout(inputs, self.p, self.training, self.inplace) def normalfunc(self, inputs):
with seed(self.mode):
def forward(self, inputs): return F.dropout(inputs, self.p, self.training, self.inplace)
return self.func(inputs)
def forward(self, inputs):
return self.func(inputs)
class WrappedDropPath(nn.Module):
r"""Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
Here, it is wrapped with the context of seed manager. class WrappedDropPath(nn.Module):
r"""Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
Args: Here, it is wrapped with the context of seed manager.
p (float, optional): probability of dropping path, defaults 0.0.
mode (:class:`colossalai.context.ParallelMode`): The chosen parallel mode. Args:
p (float, optional): probability of dropping path, defaults 0.0.
Note: mode (:class:`colossalai.context.ParallelMode`): The chosen parallel mode.
The parallel_mode should be concluded in ``ParallelMode``. More details about ``ParallelMode`` could be found
in `parallel_mode <https://github.com/hpcaitech/ColossalAI/blob/main/colossalai/context/parallel_mode.py>`_ Note:
""" The parallel_mode should be concluded in ``ParallelMode``. More details about ``ParallelMode`` could be found
in `parallel_mode <https://github.com/hpcaitech/ColossalAI/blob/main/colossalai/context/parallel_mode.py>`_
def __init__(self, p: float = 0., mode=None): """
super().__init__()
self.p = p def __init__(self, p: float = 0., mode=None):
self.mode = mode super().__init__()
if self.mode is None: self.p = p
self.func = self.nonefunc self.mode = mode
else: if self.mode is None:
self.func = self.normalfunc self.func = self.nonefunc
self.mode = mode else:
self.func = self.normalfunc
def nonefunc(self, inputs): self.mode = mode
return drop_path(inputs, self.p, self.training)
def nonefunc(self, inputs):
def normalfunc(self, inputs): return drop_path(inputs, self.p, self.training)
with seed(self.mode):
return drop_path(inputs, self.p, self.training) def normalfunc(self, inputs):
with seed(self.mode):
def forward(self, inputs): return drop_path(inputs, self.p, self.training)
return self.func(inputs)
def forward(self, inputs):
return self.func(inputs)
@LAYERS.register_module
class VanillaPatchEmbedding(nn.Module):
r""" @LAYERS.register_module
2D Image to Patch Embedding class VanillaPatchEmbedding(nn.Module):
r"""
Args: 2D Image to Patch Embedding
img_size (int): image size.
patch_size (int): patch size. Args:
in_chans (int): number of channels of input image. img_size (int): image size.
embed_size (int): size of embedding. patch_size (int): patch size.
dtype (:class:`torch.dtype`, optional): The dtype of parameters, defaults to None. in_chans (int): number of channels of input image.
flatten (bool, optional): whether to flatten output tensor, defaults to True. embed_size (int): size of embedding.
weight_initializer (:class:`typing.Callable`, optional): dtype (:class:`torch.dtype`, optional): The dtype of parameters, defaults to None.
The initializer of weight, defaults to kaiming uniform initializer. flatten (bool, optional): whether to flatten output tensor, defaults to True.
bias_initializer (:class:`typing.Callable`, optional): weight_initializer (:class:`typing.Callable`, optional):
The initializer of bias, defaults to xavier uniform initializer. The initializer of weight, defaults to kaiming uniform initializer.
position_embed_initializer (:class:`typing.Callable`, optional): bias_initializer (:class:`typing.Callable`, optional):
The initializer of position embedding, defaults to zeros initializer. The initializer of bias, defaults to xavier uniform initializer.
position_embed_initializer (:class:`typing.Callable`, optional):
More details about initializer please refer to The initializer of position embedding, defaults to zeros initializer.
`init <https://github.com/hpcaitech/ColossalAI/blob/main/colossalai/nn/init.py>`_.
""" More details about initializer please refer to
`init <https://github.com/hpcaitech/ColossalAI/blob/main/colossalai/nn/init.py>`_.
def __init__(self, """
img_size: int,
patch_size: int, def __init__(self,
in_chans: int, img_size: int,
embed_size: int, patch_size: int,
flatten: bool = True, in_chans: int,
dtype: torch.dtype = None, embed_size: int,
weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)), flatten: bool = True,
bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1), dtype: torch.dtype = None,
position_embed_initializer: Callable = init.zeros_()): weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)),
super().__init__() bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1),
img_size = to_2tuple(img_size) position_embed_initializer: Callable = init.zeros_()):
patch_size = to_2tuple(patch_size) super().__init__()
self.img_size = img_size img_size = to_2tuple(img_size)
self.patch_size = patch_size patch_size = to_2tuple(patch_size)
self.grid_size = (img_size[0] // patch_size[0], img_size[1] // patch_size[1]) self.img_size = img_size
self.num_patches = self.grid_size[0] * self.grid_size[1] self.patch_size = patch_size
self.flatten = flatten self.grid_size = (img_size[0] // patch_size[0], img_size[1] // patch_size[1])
self.num_patches = self.grid_size[0] * self.grid_size[1]
self.weight = nn.Parameter( self.flatten = flatten
torch.empty((embed_size, in_chans, *self.patch_size), device=get_current_device(), dtype=dtype))
self.bias = nn.Parameter(torch.empty(embed_size, device=get_current_device(), dtype=dtype)) self.weight = nn.Parameter(
self.cls_token = nn.Parameter(torch.zeros((1, 1, embed_size), device=get_current_device(), dtype=dtype)) torch.empty((embed_size, in_chans, *self.patch_size), device=get_current_device(), dtype=dtype))
self.pos_embed = nn.Parameter( self.bias = nn.Parameter(torch.empty(embed_size, device=get_current_device(), dtype=dtype))
torch.zeros((1, self.num_patches + 1, embed_size), device=get_current_device(), dtype=dtype)) self.cls_token = nn.Parameter(torch.zeros((1, 1, embed_size), device=get_current_device(), dtype=dtype))
self.pos_embed = nn.Parameter(
self.reset_parameters(weight_initializer, bias_initializer, position_embed_initializer) torch.zeros((1, self.num_patches + 1, embed_size), device=get_current_device(), dtype=dtype))
def reset_parameters(self, weight_initializer, bias_initializer, position_embed_initializer): self.reset_parameters(weight_initializer, bias_initializer, position_embed_initializer)
fan_in, fan_out = nn.init._calculate_fan_in_and_fan_out(self.weight)
weight_initializer(self.weight, fan_in=fan_in, fan_out=fan_out) def reset_parameters(self, weight_initializer, bias_initializer, position_embed_initializer):
bias_initializer(self.bias, fan_in=fan_in) fan_in, fan_out = nn.init._calculate_fan_in_and_fan_out(self.weight)
position_embed_initializer(self.pos_embed) weight_initializer(self.weight, fan_in=fan_in, fan_out=fan_out)
bias_initializer(self.bias, fan_in=fan_in)
def forward(self, input_: Tensor) -> Tensor: position_embed_initializer(self.pos_embed)
B, C, H, W = input_.shape
assert H == self.img_size[0] and W == self.img_size[1], \ def forward(self, input_: Tensor) -> Tensor:
f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})." B, C, H, W = input_.shape
output = F.conv2d(input_, self.weight, self.bias, stride=self.patch_size) assert H == self.img_size[0] and W == self.img_size[1], \
if self.flatten: f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
output = output.flatten(2).transpose(1, 2) # BCHW -> BNC output = F.conv2d(input_, self.weight, self.bias, stride=self.patch_size)
if self.flatten:
cls_token = self.cls_token.expand(output.shape[0], -1, -1) output = output.flatten(2).transpose(1, 2) # BCHW -> BNC
output = torch.cat((cls_token, output), dim=1)
output = output + self.pos_embed cls_token = self.cls_token.expand(output.shape[0], -1, -1)
return output output = torch.cat((cls_token, output), dim=1)
output = output + self.pos_embed
return output
@LAYERS.register_module
class VanillaClassifier(nn.Module):
r"""Dense linear classifier. @LAYERS.register_module
class VanillaClassifier(nn.Module):
Args: r"""Dense linear classifier.
in_features (int): size of each input sample.
num_classes (int): number of classes. Args:
weight (:class:`torch.nn.Parameter`, optional): weight of the classifier, defaults to None. in_features (int): size of each input sample.
dtype (:class:`torch.dtype`, optional): The dtype of parameters, defaults to None. num_classes (int): number of classes.
flatten (bool, optional): whether to flatten output tensor, defaults to True. weight (:class:`torch.nn.Parameter`, optional): weight of the classifier, defaults to None.
weight_initializer (:class:`typing.Callable`, optional): dtype (:class:`torch.dtype`, optional): The dtype of parameters, defaults to None.
The initializer of weight, defaults to kaiming uniform initializer. flatten (bool, optional): whether to flatten output tensor, defaults to True.
bias_initializer (:class:`typing.Callable`, optional): weight_initializer (:class:`typing.Callable`, optional):
The initializer of bias, defaults to xavier uniform initializer. The initializer of weight, defaults to kaiming uniform initializer.
bias_initializer (:class:`typing.Callable`, optional):
More details about initializer please refer to The initializer of bias, defaults to xavier uniform initializer.
`init <https://github.com/hpcaitech/ColossalAI/blob/main/colossalai/nn/init.py>`_.
""" More details about initializer please refer to
`init <https://github.com/hpcaitech/ColossalAI/blob/main/colossalai/nn/init.py>`_.
def __init__(self, """
in_features: int,
num_classes: int, def __init__(self,
weight: nn.Parameter = None, in_features: int,
bias: bool = True, num_classes: int,
dtype: torch.dtype = None, weight: nn.Parameter = None,
weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)), bias: bool = True,
bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1)): dtype: torch.dtype = None,
super().__init__() weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)),
self.in_features = in_features bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1)):
self.num_classes = num_classes super().__init__()
self.in_features = in_features
if weight is not None: self.num_classes = num_classes
self.weight = weight
self.has_weight = False if weight is not None:
else: self.weight = weight
self.weight = nn.Parameter( self.has_weight = False
torch.empty(self.num_classes, self.in_features, device=get_current_device(), dtype=dtype)) else:
self.has_weight = True self.weight = nn.Parameter(
if bias: torch.empty(self.num_classes, self.in_features, device=get_current_device(), dtype=dtype))
self.bias = nn.Parameter(torch.zeros(self.num_classes, device=get_current_device(), dtype=dtype)) self.has_weight = True
else: if bias:
self.bias = None self.bias = nn.Parameter(torch.zeros(self.num_classes, device=get_current_device(), dtype=dtype))
else:
self.reset_parameters(weight_initializer, bias_initializer) self.bias = None
def reset_parameters(self, weight_initializer, bias_initializer): self.reset_parameters(weight_initializer, bias_initializer)
fan_in, fan_out = self.in_features, self.num_classes
def reset_parameters(self, weight_initializer, bias_initializer):
if self.has_weight: fan_in, fan_out = self.in_features, self.num_classes
weight_initializer(self.weight, fan_in=fan_in, fan_out=fan_out)
if self.has_weight:
if self.bias is not None: weight_initializer(self.weight, fan_in=fan_in, fan_out=fan_out)
bias_initializer(self.bias, fan_in=fan_in)
if self.bias is not None:
def forward(self, input_: Tensor) -> Tensor: bias_initializer(self.bias, fan_in=fan_in)
return F.linear(input_, self.weight, self.bias)
def forward(self, input_: Tensor) -> Tensor:
return F.linear(input_, self.weight, self.bias)
@LAYERS.register_module
class VanillaLayerNorm(nn.Module):
r""" @LAYERS.register_module
Layer Normalization for colossalai class VanillaLayerNorm(nn.Module):
r"""
Args: Layer Normalization for colossalai
normalized_shape (int): input shape from an expected input of size.
:math:`[* \times \text{normalized_shape}[0] \times \text{normalized_shape}[1] Args:
\times \ldots \times \text{normalized_shape}[-1]]` normalized_shape (int): input shape from an expected input of size.
If a single integer is used, it is treated as a singleton list, and this module will :math:`[* \times \text{normalized_shape}[0] \times \text{normalized_shape}[1]
normalize over the last dimension which is expected to be of that specific size. \times \ldots \times \text{normalized_shape}[-1]]`
eps (float): a value added to the denominator for numerical stability, defaults to 1e-05. If a single integer is used, it is treated as a singleton list, and this module will
bias (bool, optional): Whether to add a bias, defaults to ``True``. normalize over the last dimension which is expected to be of that specific size.
dtype (:class:`torch.dtype`, optional): The dtype of parameters, defaults to None. eps (float): a value added to the denominator for numerical stability, defaults to 1e-05.
""" bias (bool, optional): Whether to add a bias, defaults to ``True``.
dtype (:class:`torch.dtype`, optional): The dtype of parameters, defaults to None.
def __init__(self, normalized_shape: int, eps=1e-05, bias=True, dtype=None): """
super().__init__()
def __init__(self, normalized_shape: int, eps=1e-05, bias=True, dtype=None):
self.normalized_shape = (normalized_shape,) super().__init__()
self.variance_epsilon = eps
self.normalized_shape = (normalized_shape,)
factory_kwargs = {'device': get_current_device(), 'dtype': dtype} self.variance_epsilon = eps
self.weight = nn.Parameter(torch.ones(normalized_shape, **factory_kwargs)) factory_kwargs = {'device': get_current_device(), 'dtype': dtype}
if bias:
self.bias = nn.Parameter(torch.zeros(normalized_shape, **factory_kwargs)) self.weight = nn.Parameter(torch.ones(normalized_shape, **factory_kwargs))
else: if bias:
self.bias = None self.bias = nn.Parameter(torch.zeros(normalized_shape, **factory_kwargs))
else:
def forward(self, x: Tensor) -> Tensor: self.bias = None
return F.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.variance_epsilon)
def forward(self, x: Tensor) -> Tensor:
return F.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.variance_epsilon)
@LAYERS.register_module
class VanillaLinear(nn.Module):
    """Plain (non-parallel) linear layer.

    Args:
        in_features (int): size of each input sample.
        out_features (int): size of each output sample.
        bias (bool, optional): If set to ``False``, the layer will not learn an additive bias, defaults to ``True``.
        dtype (:class:`torch.dtype`, optional): The dtype of parameters, defaults to None.
        skip_bias_add: bool (optional, default to be false).
        weight_initializer (:class:`typing.Callable`, optional):
            The initializer of weight, defaults to kaiming uniform initializer.
        bias_initializer (:class:`typing.Callable`, optional):
            The initializer of bias, defaults to xavier uniform initializer.

    More details about ``initializer`` please refer to
    `init <https://github.com/hpcaitech/ColossalAI/blob/main/colossalai/nn/init.py>`_.
    """

    def __init__(self,
                 in_features: int,
                 out_features: int,
                 bias: bool = True,
                 dtype: torch.dtype = None,
                 skip_bias_add: bool = False,
                 weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)),
                 bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1),
                 **kwargs) -> None:
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.skip_bias_add = skip_bias_add

        # Allocate parameters on the current device with the requested dtype.
        param_kwargs = {'device': get_current_device(), 'dtype': dtype}
        self.weight = Parameter(torch.empty(self.out_features, self.in_features, **param_kwargs))
        self.bias = Parameter(torch.empty(self.out_features, **param_kwargs)) if bias else None

        # Initialize in place; fan sizes follow the (out_features, in_features) weight layout.
        weight_initializer(self.weight, fan_in=in_features, fan_out=out_features)
        if self.bias is not None:
            bias_initializer(self.bias, fan_in=in_features)

    def forward(self, input: Tensor) -> Tensor:
        # With skip_bias_add set, the bias is returned alongside the raw
        # matmul output so the caller can apply (or fuse) it later.
        if self.skip_bias_add:
            return F.linear(input, self.weight), self.bias
        return F.linear(input, self.weight, self.bias)
import math import math
from typing import Optional
import torch import torch
from colossalai.kernel.op_builder import CPUAdamBuilder
from colossalai.registry import OPTIMIZERS from colossalai.registry import OPTIMIZERS
from .nvme_optimizer import NVMeOptimizer from .nvme_optimizer import NVMeOptimizer
from typing import Optional
@OPTIMIZERS.register_module @OPTIMIZERS.register_module
...@@ -11,12 +14,12 @@ class CPUAdam(NVMeOptimizer): ...@@ -11,12 +14,12 @@ class CPUAdam(NVMeOptimizer):
"""Implements Adam algorithm. """Implements Adam algorithm.
Supports parameters updating on both GPU and CPU, depanding on the device of paramters. Supports parameters updating on both GPU and CPU, depanding on the device of paramters.
But the parameters and gradients should on the same device: But the parameters and gradients should on the same device:
* Parameters on CPU and gradients on CPU is allowed. * Parameters on CPU and gradients on CPU is allowed.
* Parameters on GPU and gradients on GPU is allowed. * Parameters on GPU and gradients on GPU is allowed.
* Parameters on GPU and gradients on CPU is **not** allowed. * Parameters on GPU and gradients on CPU is **not** allowed.
Requires ColossalAI to be installed via ``pip install .``. `CPUAdam` requires CUDA extensions which can be built during installation or runtime.
This version of CPU Adam accelates parameters updating on CPU with SIMD. This version of CPU Adam accelates parameters updating on CPU with SIMD.
Support of AVX2 or AVX512 is required. Support of AVX2 or AVX512 is required.
...@@ -44,7 +47,7 @@ class CPUAdam(NVMeOptimizer): ...@@ -44,7 +47,7 @@ class CPUAdam(NVMeOptimizer):
(default: False) NOT SUPPORTED yet in CPUAdam! (default: False) NOT SUPPORTED yet in CPUAdam!
adamw_mode (boolean, optional): Apply L2 regularization or weight decay adamw_mode (boolean, optional): Apply L2 regularization or weight decay
True for decoupled weight decay(also known as AdamW) (default: True) True for decoupled weight decay(also known as AdamW) (default: True)
simd_log (boolean, optional): whether to show if you are using SIMD to simd_log (boolean, optional): whether to show if you are using SIMD to
accelerate. (default: False) accelerate. (default: False)
nvme_offload_fraction (float, optional): Fraction of optimizer states to be offloaded to NVMe. Defaults to 0.0. nvme_offload_fraction (float, optional): Fraction of optimizer states to be offloaded to NVMe. Defaults to 0.0.
nvme_offload_dir (Optional[str], optional): Directory to save NVMe offload files. nvme_offload_dir (Optional[str], optional): Directory to save NVMe offload files.
...@@ -74,10 +77,7 @@ class CPUAdam(NVMeOptimizer): ...@@ -74,10 +77,7 @@ class CPUAdam(NVMeOptimizer):
default_args = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay, bias_correction=bias_correction) default_args = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay, bias_correction=bias_correction)
super(CPUAdam, self).__init__(model_params, default_args, nvme_offload_fraction, nvme_offload_dir) super(CPUAdam, self).__init__(model_params, default_args, nvme_offload_fraction, nvme_offload_dir)
self.adamw_mode = adamw_mode self.adamw_mode = adamw_mode
try: cpu_adam = CPUAdamBuilder().load()
import cpu_adam
except ImportError:
raise ImportError('Please install colossalai from source code to use CPUAdam')
self.cpu_adam_op = cpu_adam.CPUAdamOptimizer(lr, betas[0], betas[1], eps, weight_decay, adamw_mode) self.cpu_adam_op = cpu_adam.CPUAdamOptimizer(lr, betas[0], betas[1], eps, weight_decay, adamw_mode)
def torch_adam_update(self, def torch_adam_update(self,
...@@ -114,7 +114,7 @@ class CPUAdam(NVMeOptimizer): ...@@ -114,7 +114,7 @@ class CPUAdam(NVMeOptimizer):
data.addcdiv_(exp_avg, denom, value=-step_size) data.addcdiv_(exp_avg, denom, value=-step_size)
@torch.no_grad() @torch.no_grad()
def step(self, closure=None): def step(self, closure=None, div_scale: float = -1):
loss = None loss = None
if closure is not None: if closure is not None:
with torch.enable_grad(): with torch.enable_grad():
...@@ -149,9 +149,10 @@ class CPUAdam(NVMeOptimizer): ...@@ -149,9 +149,10 @@ class CPUAdam(NVMeOptimizer):
self._pre_update(p, 'exp_avg', 'exp_avg_sq') self._pre_update(p, 'exp_avg', 'exp_avg_sq')
self.cpu_adam_op.step(state['step'], group['lr'], beta1, beta2, group['eps'], group['weight_decay'], self.cpu_adam_op.step(state['step'], group['lr'], beta1, beta2, group['eps'], group['weight_decay'],
group['bias_correction'], p.data, p.grad.data, state['exp_avg'], group['bias_correction'], p.data, p.grad.data, state['exp_avg'],
state['exp_avg_sq'], -1) state['exp_avg_sq'], div_scale)
self._post_update(p, 'exp_avg', 'exp_avg_sq') self._post_update(p, 'exp_avg', 'exp_avg_sq')
elif target_device.type == 'cuda': elif target_device.type == 'cuda':
assert div_scale == -1, "div_scale should remain default"
assert state['exp_avg'].device.type == 'cuda', "exp_avg should stay on cuda" assert state['exp_avg'].device.type == 'cuda', "exp_avg should stay on cuda"
assert state['exp_avg_sq'].device.type == 'cuda', "exp_avg should stay on cuda" assert state['exp_avg_sq'].device.type == 'cuda', "exp_avg should stay on cuda"
......
...@@ -9,8 +9,7 @@ from colossalai.utils import multi_tensor_applier ...@@ -9,8 +9,7 @@ from colossalai.utils import multi_tensor_applier
class FusedAdam(torch.optim.Optimizer): class FusedAdam(torch.optim.Optimizer):
"""Implements Adam algorithm. """Implements Adam algorithm.
Currently GPU-only. Requires ColossalAI to be installed via `FusedAdam` requires CUDA extensions which can be built during installation or runtime.
``pip install .``.
This version of fused Adam implements 2 fusions. This version of fused Adam implements 2 fusions.
...@@ -20,7 +19,7 @@ class FusedAdam(torch.optim.Optimizer): ...@@ -20,7 +19,7 @@ class FusedAdam(torch.optim.Optimizer):
:class:`colossalai.nn.optimizer.FusedAdam` may be used as a drop-in replacement for ``torch.optim.AdamW``, :class:`colossalai.nn.optimizer.FusedAdam` may be used as a drop-in replacement for ``torch.optim.AdamW``,
or ``torch.optim.Adam`` with ``adamw_mode=False`` or ``torch.optim.Adam`` with ``adamw_mode=False``
:class:`colossalai.nn.optimizer.FusedAdam` may be used with or without Amp. :class:`colossalai.nn.optimizer.FusedAdam` may be used with or without Amp.
Adam was been proposed in `Adam: A Method for Stochastic Optimization`_. Adam was been proposed in `Adam: A Method for Stochastic Optimization`_.
...@@ -65,10 +64,12 @@ class FusedAdam(torch.optim.Optimizer): ...@@ -65,10 +64,12 @@ class FusedAdam(torch.optim.Optimizer):
self.adamw_mode = 1 if adamw_mode else 0 self.adamw_mode = 1 if adamw_mode else 0
self.set_grad_none = set_grad_none self.set_grad_none = set_grad_none
if multi_tensor_applier.available: if multi_tensor_applier.available:
import colossal_C from colossalai.kernel.op_builder import FusedOptimBuilder
fused_optim = FusedOptimBuilder().load()
# Skip buffer # Skip buffer
self._dummy_overflow_buf = torch.cuda.IntTensor([0]) self._dummy_overflow_buf = torch.cuda.IntTensor([0])
self.multi_tensor_adam = colossal_C.multi_tensor_adam self.multi_tensor_adam = fused_optim.multi_tensor_adam
else: else:
raise RuntimeError('FusedAdam requires cuda extensions') raise RuntimeError('FusedAdam requires cuda extensions')
...@@ -80,7 +81,7 @@ class FusedAdam(torch.optim.Optimizer): ...@@ -80,7 +81,7 @@ class FusedAdam(torch.optim.Optimizer):
else: else:
super(FusedAdam, self).zero_grad() super(FusedAdam, self).zero_grad()
def step(self, closure=None, grads=None, output_params=None, scale=None, grad_norms=None): def step(self, closure=None, grads=None, output_params=None, scale=None, grad_norms=None, div_scale: float = -1):
"""Performs a single optimization step. """Performs a single optimization step.
Arguments: Arguments:
...@@ -136,6 +137,6 @@ class FusedAdam(torch.optim.Optimizer): ...@@ -136,6 +137,6 @@ class FusedAdam(torch.optim.Optimizer):
multi_tensor_applier(self.multi_tensor_adam, self._dummy_overflow_buf, [g_l, p_l, m_l, v_l], group['lr'], multi_tensor_applier(self.multi_tensor_adam, self._dummy_overflow_buf, [g_l, p_l, m_l, v_l], group['lr'],
beta1, beta2, group['eps'], group['step'], self.adamw_mode, bias_correction, beta1, beta2, group['eps'], group['step'], self.adamw_mode, bias_correction,
group['weight_decay']) group['weight_decay'], div_scale)
return loss return loss
...@@ -9,8 +9,7 @@ from colossalai.utils import multi_tensor_applier ...@@ -9,8 +9,7 @@ from colossalai.utils import multi_tensor_applier
class FusedLAMB(torch.optim.Optimizer): class FusedLAMB(torch.optim.Optimizer):
"""Implements LAMB algorithm. """Implements LAMB algorithm.
Currently GPU-only. Requires ColossalAI to be installed via `FusedLAMB` requires CUDA extensions which can be built during installation or runtime.
``pip install .``.
This version of fused LAMB implements 2 fusions. This version of fused LAMB implements 2 fusions.
...@@ -76,13 +75,15 @@ class FusedLAMB(torch.optim.Optimizer): ...@@ -76,13 +75,15 @@ class FusedLAMB(torch.optim.Optimizer):
max_grad_norm=max_grad_norm) max_grad_norm=max_grad_norm)
super(FusedLAMB, self).__init__(params, defaults) super(FusedLAMB, self).__init__(params, defaults)
if multi_tensor_applier.available: if multi_tensor_applier.available:
import colossal_C from colossalai.kernel.op_builder import FusedOptimBuilder
self.multi_tensor_l2norm = colossal_C.multi_tensor_l2norm fused_optim = FusedOptimBuilder().load()
self.multi_tensor_l2norm = fused_optim.multi_tensor_l2norm
# Skip buffer # Skip buffer
self._dummy_overflow_buf = torch.tensor([0], self._dummy_overflow_buf = torch.tensor([0],
dtype=torch.int, dtype=torch.int,
device=self.param_groups[0]["params"][0].device) device=self.param_groups[0]["params"][0].device)
self.multi_tensor_lamb = colossal_C.multi_tensor_lamb self.multi_tensor_lamb = fused_optim.multi_tensor_lamb
else: else:
raise RuntimeError('FusedLAMB requires cuda extensions') raise RuntimeError('FusedLAMB requires cuda extensions')
......
...@@ -10,8 +10,7 @@ from colossalai.utils import multi_tensor_applier ...@@ -10,8 +10,7 @@ from colossalai.utils import multi_tensor_applier
class FusedSGD(Optimizer): class FusedSGD(Optimizer):
r"""Implements stochastic gradient descent (optionally with momentum). r"""Implements stochastic gradient descent (optionally with momentum).
Currently GPU-only. Requires ColossalAI to be installed via `FusedSGD` requires CUDA extensions which can be built during installation or runtime.
``pip install .``.
This version of fused SGD implements 2 fusions. This version of fused SGD implements 2 fusions.
...@@ -20,7 +19,7 @@ class FusedSGD(Optimizer): ...@@ -20,7 +19,7 @@ class FusedSGD(Optimizer):
:class:`colossalai.nn.optimizer.FusedSGD` may be used as a drop-in replacement for ``torch.optim.SGD`` :class:`colossalai.nn.optimizer.FusedSGD` may be used as a drop-in replacement for ``torch.optim.SGD``
:class:`colossalai.nn.optimizer.FusedSGD` may be used with or without Amp. :class:`colossalai.nn.optimizer.FusedSGD` may be used with or without Amp.
Nesterov momentum is based on the formula from Nesterov momentum is based on the formula from
`On the importance of initialization and momentum in deep learning`__. `On the importance of initialization and momentum in deep learning`__.
...@@ -80,12 +79,14 @@ class FusedSGD(Optimizer): ...@@ -80,12 +79,14 @@ class FusedSGD(Optimizer):
self.wd_after_momentum = wd_after_momentum self.wd_after_momentum = wd_after_momentum
if multi_tensor_applier.available: if multi_tensor_applier.available:
import colossal_C from colossalai.kernel.op_builder import FusedOptimBuilder
fused_optim = FusedOptimBuilder().load()
# Skip buffer # Skip buffer
self._dummy_overflow_buf = torch.tensor([0], self._dummy_overflow_buf = torch.tensor([0],
dtype=torch.int, dtype=torch.int,
device=self.param_groups[0]["params"][0].device) device=self.param_groups[0]["params"][0].device)
self.multi_tensor_sgd = colossal_C.multi_tensor_sgd self.multi_tensor_sgd = fused_optim.multi_tensor_sgd
else: else:
raise RuntimeError('FusedSGD requires cuda extensions') raise RuntimeError('FusedSGD requires cuda extensions')
......
from typing import Any
import torch
from colossalai.nn.optimizer import HybridAdam
from colossalai.nn.optimizer.zero_optimizer import ZeroOptimizer
__all__ = ['GeminiAdamOptimizer']
class GeminiAdamOptimizer(ZeroOptimizer):
    """ZeRO-wrapped Adam optimizer for Gemini.

    Builds a :class:`HybridAdam` over ``model``'s parameters and hands it to
    :class:`ZeroOptimizer`, forwarding the same keyword defaults to both.

    Args:
        model (torch.nn.Module): the model whose parameters are optimized.
        **defaults (Any): keyword arguments passed to both ``HybridAdam``
            and ``ZeroOptimizer``.
    """

    def __init__(self, model: torch.nn.Module, **defaults: Any) -> None:
        # Construct the hybrid CPU/GPU Adam first, then wrap it with ZeRO.
        hybrid_adam = HybridAdam(model.parameters(), **defaults)
        super().__init__(hybrid_adam, model, **defaults)
from typing import Any, Optional
import torch import torch
from colossalai.utils import multi_tensor_applier from colossalai.kernel.op_builder import CPUAdamBuilder, FusedOptimBuilder
from colossalai.registry import OPTIMIZERS from colossalai.registry import OPTIMIZERS
from typing import Optional from colossalai.utils import multi_tensor_applier
from .nvme_optimizer import NVMeOptimizer from .nvme_optimizer import NVMeOptimizer
...@@ -11,12 +14,12 @@ class HybridAdam(NVMeOptimizer): ...@@ -11,12 +14,12 @@ class HybridAdam(NVMeOptimizer):
"""Implements Adam algorithm. """Implements Adam algorithm.
Supports parameters updating on both GPU and CPU, depanding on the device of paramters. Supports parameters updating on both GPU and CPU, depanding on the device of paramters.
But the parameters and gradients should on the same device: But the parameters and gradients should on the same device:
* Parameters on CPU and gradients on CPU is allowed. * Parameters on CPU and gradients on CPU is allowed.
* Parameters on GPU and gradients on GPU is allowed. * Parameters on GPU and gradients on GPU is allowed.
* Parameters on GPU and gradients on CPU is **not** allowed. * Parameters on GPU and gradients on CPU is **not** allowed.
Requires ColossalAI to be installed via ``pip install .`` `HybriadAdam` requires CUDA extensions which can be built during installation or runtime.
This version of Hybrid Adam is an hybrid of CPUAdam and FusedAdam. This version of Hybrid Adam is an hybrid of CPUAdam and FusedAdam.
...@@ -43,7 +46,7 @@ class HybridAdam(NVMeOptimizer): ...@@ -43,7 +46,7 @@ class HybridAdam(NVMeOptimizer):
(default: False) NOT SUPPORTED yet in CPUAdam! (default: False) NOT SUPPORTED yet in CPUAdam!
adamw_mode (boolean, optional): Apply L2 regularization or weight decay adamw_mode (boolean, optional): Apply L2 regularization or weight decay
True for decoupled weight decay(also known as AdamW) (default: True) True for decoupled weight decay(also known as AdamW) (default: True)
simd_log (boolean, optional): whether to show if you are using SIMD to simd_log (boolean, optional): whether to show if you are using SIMD to
accelerate. (default: False) accelerate. (default: False)
nvme_offload_fraction (float, optional): Fraction of optimizer states to be offloaded to NVMe. Defaults to 0.0. nvme_offload_fraction (float, optional): Fraction of optimizer states to be offloaded to NVMe. Defaults to 0.0.
nvme_offload_dir (Optional[str], optional): Directory to save NVMe offload files. nvme_offload_dir (Optional[str], optional): Directory to save NVMe offload files.
...@@ -68,24 +71,23 @@ class HybridAdam(NVMeOptimizer): ...@@ -68,24 +71,23 @@ class HybridAdam(NVMeOptimizer):
weight_decay=0, weight_decay=0,
adamw_mode=True, adamw_mode=True,
nvme_offload_fraction: float = 0.0, nvme_offload_fraction: float = 0.0,
nvme_offload_dir: Optional[str] = None): nvme_offload_dir: Optional[str] = None,
**defaults: Any):
default_args = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay, bias_correction=bias_correction) default_args = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay, bias_correction=bias_correction)
super(HybridAdam, self).__init__(model_params, default_args, nvme_offload_fraction, nvme_offload_dir) super(HybridAdam, self).__init__(model_params, default_args, nvme_offload_fraction, nvme_offload_dir)
self.adamw_mode = adamw_mode self.adamw_mode = adamw_mode
try:
import cpu_adam
import colossal_C
except ImportError:
raise ImportError('Please install colossalai from source code to use HybridAdam')
self.cpu_adam_op = cpu_adam.CPUAdamOptimizer(lr, betas[0], betas[1], eps, weight_decay, adamw_mode) # build during runtime if not found
cpu_optim = CPUAdamBuilder().load()
fused_optim = FusedOptimBuilder().load()
self.cpu_adam_op = cpu_optim.CPUAdamOptimizer(lr, betas[0], betas[1], eps, weight_decay, adamw_mode)
self.gpu_adam_op = colossal_C.multi_tensor_adam self.gpu_adam_op = fused_optim.multi_tensor_adam
self._dummy_overflow_buf = torch.cuda.IntTensor([0]) self._dummy_overflow_buf = torch.cuda.IntTensor([0])
@torch.no_grad() @torch.no_grad()
def step(self, closure=None): def step(self, closure=None, div_scale: float = -1):
loss = None loss = None
if closure is not None: if closure is not None:
with torch.enable_grad(): with torch.enable_grad():
...@@ -122,7 +124,7 @@ class HybridAdam(NVMeOptimizer): ...@@ -122,7 +124,7 @@ class HybridAdam(NVMeOptimizer):
self._pre_update(p, 'exp_avg', 'exp_avg_sq') self._pre_update(p, 'exp_avg', 'exp_avg_sq')
self.cpu_adam_op.step(state['step'], group['lr'], beta1, beta2, group['eps'], group['weight_decay'], self.cpu_adam_op.step(state['step'], group['lr'], beta1, beta2, group['eps'], group['weight_decay'],
group['bias_correction'], p.data, p.grad.data, state['exp_avg'], group['bias_correction'], p.data, p.grad.data, state['exp_avg'],
state['exp_avg_sq'], -1) state['exp_avg_sq'], div_scale)
self._post_update(p, 'exp_avg', 'exp_avg_sq') self._post_update(p, 'exp_avg', 'exp_avg_sq')
elif target_device.type == 'cuda': elif target_device.type == 'cuda':
...@@ -142,6 +144,6 @@ class HybridAdam(NVMeOptimizer): ...@@ -142,6 +144,6 @@ class HybridAdam(NVMeOptimizer):
bias_correction = 1 if group['bias_correction'] else 0 bias_correction = 1 if group['bias_correction'] else 0
multi_tensor_applier(self.gpu_adam_op, self._dummy_overflow_buf, [g_l, p_l, m_l, v_l], group['lr'], multi_tensor_applier(self.gpu_adam_op, self._dummy_overflow_buf, [g_l, p_l, m_l, v_l], group['lr'],
group['betas'][0], group['betas'][1], group['eps'], group_step, adamw_mode, group['betas'][0], group['betas'][1], group['eps'], group_step, adamw_mode,
bias_correction, group['weight_decay']) bias_correction, group['weight_decay'], div_scale)
self._post_step() self._post_step()
return loss return loss
import math
from enum import Enum
from typing import Any, Dict, Set, Tuple
import torch import torch
import torch.distributed as dist import torch.distributed as dist
from enum import Enum
from torch.optim import Optimizer
from torch.nn import Parameter from torch.nn import Parameter
from colossalai.nn.parallel.data_parallel import ZeroDDP from torch.optim import Optimizer
from typing import Dict, Tuple, Set
from colossalai.amp.naive_amp.grad_scaler import DynamicGradScaler from colossalai.amp.naive_amp.grad_scaler import DynamicGradScaler
from colossalai.logging import get_dist_logger
from colossalai.nn.optimizer import ColossalaiOptimizer
from colossalai.utils import get_current_device, disposable
from colossalai.gemini.chunk import Chunk, ChunkManager from colossalai.gemini.chunk import Chunk, ChunkManager
from colossalai.logging import get_dist_logger
from colossalai.nn.optimizer import ColossalaiOptimizer, CPUAdam, FusedAdam, HybridAdam
from colossalai.nn.parallel.data_parallel import ZeroDDP
from colossalai.utils import disposable, get_current_device
_AVAIL_OPTIM_LIST = {FusedAdam, CPUAdam, HybridAdam}
class OptimState(Enum): class OptimState(Enum):
...@@ -53,9 +58,13 @@ class ZeroOptimizer(ColossalaiOptimizer): ...@@ -53,9 +58,13 @@ class ZeroOptimizer(ColossalaiOptimizer):
backoff_factor: float = 0.5, backoff_factor: float = 0.5,
growth_interval: int = 1000, growth_interval: int = 1000,
hysteresis: int = 2, hysteresis: int = 2,
max_scale: float = 2**32): max_scale: float = 2**32,
clipping_norm: float = 0.0,
norm_type: float = 2.0,
**defaults: Any):
super().__init__(optim) super().__init__(optim)
assert isinstance(module, ZeroDDP) assert isinstance(module, ZeroDDP)
assert type(optim) in _AVAIL_OPTIM_LIST, "you should use the optimizer in the available list"
self.module = module self.module = module
self.gemini_manager = module.gemini_manager self.gemini_manager = module.gemini_manager
self.chunk_manager: ChunkManager = self.gemini_manager.chunk_manager self.chunk_manager: ChunkManager = self.gemini_manager.chunk_manager
...@@ -63,11 +72,17 @@ class ZeroOptimizer(ColossalaiOptimizer): ...@@ -63,11 +72,17 @@ class ZeroOptimizer(ColossalaiOptimizer):
self.param_to_range: Dict[Parameter, Tuple[int, int]] = dict() self.param_to_range: Dict[Parameter, Tuple[int, int]] = dict()
self.param_to_chunk32: Dict[Parameter, Chunk] = dict() self.param_to_chunk32: Dict[Parameter, Chunk] = dict()
self.chunk16_set: Set[Chunk] = set() self.chunk16_set: Set[Chunk] = set()
self.clipping_flag = clipping_norm > 0.0
self.max_norm = clipping_norm
if self.clipping_flag:
assert norm_type == 2.0, "ZeroOptimizer only supports L2 norm now"
params_list = [p for p in module.parameters() if not getattr(p, '_ddp_to_ignore', False)] params_list = [p for p in module.parameters() if not getattr(p, '_ddp_to_ignore', False)]
for p, fp32_p in zip(params_list, module.fp32_params): for p, fp32_p in zip(params_list, module.fp32_params):
chunk_16 = self.chunk_manager.get_chunk(p) chunk_16 = self.chunk_manager.get_chunk(p)
if chunk_16 not in self.chunk16_set: if chunk_16 not in self.chunk16_set:
chunk_16.l2_norm_flag = self.clipping_flag
self.chunk16_set.add(chunk_16) self.chunk16_set.add(chunk_16)
self.__init__optimizer() self.__init__optimizer()
...@@ -125,13 +140,49 @@ class ZeroOptimizer(ColossalaiOptimizer): ...@@ -125,13 +140,49 @@ class ZeroOptimizer(ColossalaiOptimizer):
return self._found_overflow.item() > 0 return self._found_overflow.item() > 0
def _unscale_grads(self): def _calc_global_norm(self) -> float:
assert self.optim_state == OptimState.SCALED norm_sqr: float = 0.0
for group in self.optim.param_groups: group_to_norm = dict()
for p in group['params']: for c16 in self.chunk16_set:
if p.grad is not None: assert c16.l2_norm is not None
p.grad.data.div_(self.loss_scale)
self.optim_state = OptimState.UNSCALED if c16.is_gathered:
norm_sqr += c16.l2_norm
else:
# this chunk is sharded, use communication to collect total norm
if c16.torch_pg not in group_to_norm:
group_to_norm[c16.torch_pg] = 0.0
group_to_norm[c16.torch_pg] += c16.l2_norm
c16.l2_norm = None # clear l2 norm
comm_buffer = torch.zeros(1, dtype=torch.float, device=get_current_device())
for group, part_norm in group_to_norm.items():
comm_buffer.fill_(part_norm)
dist.all_reduce(comm_buffer, group=group)
norm_sqr += comm_buffer.item()
global_norm = math.sqrt(norm_sqr)
return global_norm
def _get_combined_scale(self):
loss_scale = 1
if self.optim_state == OptimState.SCALED:
loss_scale = self.loss_scale
self.optim_state = OptimState.UNSCALED
combined_scale = loss_scale
if self.clipping_flag:
total_norm = self._calc_global_norm()
clip = ((total_norm / loss_scale) + 1e-6) / self.max_norm
if clip > 1:
combined_scale = clip * loss_scale
if combined_scale == 1:
return -1
else:
return combined_scale
@property @property
def loss_scale(self): def loss_scale(self):
...@@ -144,17 +195,22 @@ class ZeroOptimizer(ColossalaiOptimizer): ...@@ -144,17 +195,22 @@ class ZeroOptimizer(ColossalaiOptimizer):
def step(self, *args, **kwargs): def step(self, *args, **kwargs):
self._maybe_move_fp32_params() self._maybe_move_fp32_params()
self._set_grad_ptr() self._set_grad_ptr()
# unscale grads if scaled
if self.optim_state == OptimState.SCALED:
self._unscale_grads()
found_inf = self._check_overflow() found_inf = self._check_overflow()
self.grad_scaler.update(found_inf)
if found_inf: if found_inf:
self.optim_state = OptimState.UNSCALED # no need to unscale grad
self.grad_scaler.update(found_inf) # update gradient scaler
self._logger.info(f'Found overflow. Skip step') self._logger.info(f'Found overflow. Skip step')
self.zero_grad() self.zero_grad() # reset all gradients
self._update_fp16_params() self._update_fp16_params()
return return
ret = self.optim.step(*args, **kwargs)
# get combined scale. combined scale = loss scale * clipping norm
# so that gradient = gradient / combined scale
combined_scale = self._get_combined_scale()
self.grad_scaler.update(found_inf)
ret = self.optim.step(div_scale=combined_scale, *args, **kwargs)
self._register_states() self._register_states()
self.zero_grad() self.zero_grad()
self._update_fp16_params() self._update_fp16_params()
...@@ -219,6 +275,8 @@ class ZeroOptimizer(ColossalaiOptimizer): ...@@ -219,6 +275,8 @@ class ZeroOptimizer(ColossalaiOptimizer):
def get_range_pair(local_chunk: Chunk, local_param: Parameter): def get_range_pair(local_chunk: Chunk, local_param: Parameter):
param_info = local_chunk.tensors_info[local_param] param_info = local_chunk.tensors_info[local_param]
if local_chunk.keep_gathered:
return param_info.offset, param_info.end
begin = max(0, param_info.offset - local_chunk.shard_begin) begin = max(0, param_info.offset - local_chunk.shard_begin)
end = min(local_chunk.shard_size, param_info.end - local_chunk.shard_begin) end = min(local_chunk.shard_size, param_info.end - local_chunk.shard_begin)
return begin, end return begin, end
......
from .data_parallel import ColoDDP, ZeroDDP from .data_parallel import ColoDDP, ZeroDDP
from .gemini_parallel import GeminiDDP
__all__ = ['ColoDDP', 'ZeroDDP'] __all__ = ['ColoDDP', 'ZeroDDP', 'GeminiDDP']
import torch
import itertools import itertools
import torch.distributed as dist from collections import OrderedDict
from functools import partial from functools import partial
from colossalai.zero.utils.zero_hook_v2 import ZeROHookV2
from colossalai.tensor.param_op_hook import ParamOpHookManager
from colossalai.gemini.gemini_mgr import GeminiManager
from typing import Dict, Iterable, List, Optional, Set from typing import Dict, Iterable, List, Optional, Set
import torch
import torch.distributed as dist
from colossalai.gemini.chunk import Chunk, ChunkManager, TensorState
from colossalai.gemini.gemini_mgr import GeminiManager
from colossalai.gemini.memory_tracer import OrderedParamGenerator
from colossalai.logging import get_dist_logger from colossalai.logging import get_dist_logger
from collections import OrderedDict from colossalai.nn.parallel.utils import get_temp_total_chunk_on_cuda
from colossalai.tensor.colo_parameter import ColoParameter, ColoTensor, ColoTensorSpec
from colossalai.tensor import ProcessGroup as ColoProcessGroup from colossalai.tensor import ProcessGroup as ColoProcessGroup
from .reducer import Reducer from colossalai.tensor.colo_parameter import ColoParameter, ColoTensor, ColoTensorSpec
from colossalai.tensor.param_op_hook import ColoParamOpHookManager
from colossalai.utils import get_current_device
from colossalai.zero.utils.gemini_hook import GeminiZeROHook
from colossalai.gemini.chunk import TensorState, Chunk, ChunkManager from .reducer import Reducer
from colossalai.nn.parallel.utils import get_temp_total_chunk_on_cuda from .utils import get_static_torch_model
try: try:
from torch.nn.modules.module import _EXTRA_STATE_KEY_SUFFIX, _IncompatibleKeys from torch.nn.modules.module import _EXTRA_STATE_KEY_SUFFIX, _IncompatibleKeys
...@@ -185,25 +190,16 @@ class ColoDDP(torch.nn.Module): ...@@ -185,25 +190,16 @@ class ColoDDP(torch.nn.Module):
class ZeroDDP(ColoDDP): class ZeroDDP(ColoDDP):
"""ZeRO-DP for ColoTensor. Nested ZeroDDP is not supported now. """ZeRO DDP for ColoTensor.
We can configure chunk and gemini via ChunkManager and GeminiManager respectively. Warning: Nested ZeroDDP is not supported now.
It is designed to be used with ChunkManager and GeminiManager.
For more details, see the API reference of ``ChunkManager`` and ``GeminiManager``. For more details, see the API reference of ``ChunkManager`` and ``GeminiManager``.
Example:
>>> model = torch.nn.Linear(20, 1)
>>> placement_policy = 'cuda'
>>> chunk_size = ChunkManager.search_chunk_size(model, search_range, n_grids) if use_chunk else None
>>> chunk_manager = ChunkManager(chunk_size, enable_distributed_storage=use_zero, init_device=GeminiManager.get_default_device(placement_policy))
>>> gemini_manager = GeminiManager(placement_policy, chunk_manager)
>>> model = ZeroDDP(model, gemini_manager)
>>> logits = model(x)
>>> loss = criterion(logits, labels)
>>> model.backward(loss)
Args: Args:
module (torch.nn.Module): Module to apply ZeRO-DP. module (torch.nn.Module): Module to apply ZeRO-DP.
gemini_manager (GeminiManager): Manages the chunk manager and heterogeneous momery space. gemini_manager (GeminiManager): Manages the chunk manager and heterogeneous momery space.
For more details, see the API reference of ``GeminiManager``. For more details, see the API reference of ``GeminiManager``.
pin_memory (bool): Chunks on CPU Memory use pin-memory.
force_outputs_fp32 (bool): If set to True, outputs will be fp32. Otherwise, outputs will be fp16. Defaults to False. force_outputs_fp32 (bool): If set to True, outputs will be fp32. Otherwise, outputs will be fp16. Defaults to False.
""" """
...@@ -216,13 +212,24 @@ class ZeroDDP(ColoDDP): ...@@ -216,13 +212,24 @@ class ZeroDDP(ColoDDP):
self.gemini_manager = gemini_manager self.gemini_manager = gemini_manager
self.chunk_manager: ChunkManager = gemini_manager.chunk_manager self.chunk_manager: ChunkManager = gemini_manager.chunk_manager
self.force_outputs_fp32 = force_outputs_fp32 self.force_outputs_fp32 = force_outputs_fp32
self.param_op_hook = ZeROHookV2(gemini_manager) self.param_op_hook = GeminiZeROHook(gemini_manager)
self.fp32_params: List[ColoTensor] = [] self.fp32_params: List[ColoTensor] = []
self.overflow_counter = 0 self.overflow_counter = 0
self.grads_device: Dict[torch.Tensor, torch.device] = {} self.grads_device: Dict[torch.Tensor, torch.device] = {}
# TODO: get param order and filter unused params cpu_offload = self.gemini_manager.policy_name != 'cuda'
for p in module.parameters():
if self.gemini_manager._premade_memstats_:
# build chunk in param runtime visited order.
param_order = self.gemini_manager.memstats()._param_runtime_order
else:
# build chunk in param initialized order.
# Note: in this way, it can not get filter unused params during runtime.
param_order = OrderedParamGenerator()
for p in module.parameters():
param_order.append(p)
for p in param_order.generate():
assert isinstance(p, ColoParameter) assert isinstance(p, ColoParameter)
if getattr(p, '_ddp_to_ignore', False): if getattr(p, '_ddp_to_ignore', False):
...@@ -232,28 +239,40 @@ class ZeroDDP(ColoDDP): ...@@ -232,28 +239,40 @@ class ZeroDDP(ColoDDP):
fp32_data = p.data.float() fp32_data = p.data.float()
fp32_p = ColoTensor(fp32_data, spec=ColoTensorSpec(p.process_group)) fp32_p = ColoTensor(fp32_data, spec=ColoTensorSpec(p.process_group))
p.data = p.data.half() p.data = p.data.half()
dp_world_size = p.process_group.dp_world_size() dp_world_size = p.process_group.dp_world_size()
self.chunk_manager.append_tensor(p, 'fp16_param', dp_world_size, pin_memory) self.chunk_manager.register_tensor(tensor=p,
self.chunk_manager.append_tensor(fp32_p, 'fp32_param', dp_world_size, pin_memory) group_type='fp16_param',
config_key=dp_world_size,
cpu_offload=cpu_offload,
pin_memory=pin_memory)
self.chunk_manager.register_tensor(tensor=fp32_p,
group_type='fp32_param',
config_key=dp_world_size,
cpu_offload=cpu_offload,
pin_memory=pin_memory)
self.fp32_params.append(fp32_p) self.fp32_params.append(fp32_p)
self.grads_device[p] = self.gemini_manager.default_device self.grads_device[p] = self.gemini_manager.default_device
self.chunk_manager.close_all_groups() self.chunk_manager.close_all_groups()
self._cast_buffers() self._cast_buffers()
params_list = [p for p in module.parameters() if not getattr(p, '_ddp_to_ignore', False)] params_list = [p for p in param_order.generate() if not getattr(p, '_ddp_to_ignore', False)]
for p, fp32_p in zip(params_list, self.fp32_params): for p, fp32_p in zip(params_list, self.fp32_params):
chunk_16 = self.chunk_manager.get_chunk(p) chunk_16 = self.chunk_manager.get_chunk(p)
chunk_32 = self.chunk_manager.get_chunk(fp32_p) chunk_32 = self.chunk_manager.get_chunk(fp32_p)
chunk_32.init_pair(chunk_16) chunk_32.init_pair(chunk_16)
# keep gathered chunks are in CUDA
if chunk_16.keep_gathered:
self.grads_device[p] = get_current_device()
self._logger = get_dist_logger() self._logger = get_dist_logger()
def forward(self, *args, **kwargs): def forward(self, *args, **kwargs):
args, kwargs = _cast_float(args, torch.half), _cast_float(kwargs, torch.half) args, kwargs = _cast_float(args, torch.half), _cast_float(kwargs, torch.half)
self.module.zero_grad(set_to_none=True) self.module.zero_grad(set_to_none=True)
self.gemini_manager.pre_iter() self.gemini_manager.pre_iter(*args)
with ParamOpHookManager.use_hooks(self.param_op_hook): with ColoParamOpHookManager.use_hooks(self.param_op_hook):
outputs = self.module(*args, **kwargs) outputs = self.module(*args, **kwargs)
if self.force_outputs_fp32: if self.force_outputs_fp32:
return _cast_float(outputs, torch.float) return _cast_float(outputs, torch.float)
...@@ -266,7 +285,9 @@ class ZeroDDP(ColoDDP): ...@@ -266,7 +285,9 @@ class ZeroDDP(ColoDDP):
p.grad = None p.grad = None
def _post_backward(self): def _post_backward(self):
assert self.chunk_manager.accessed_mem == 0 if self.chunk_manager.accessed_mem != 0:
raise RuntimeError("ZERO DDP error: the synchronization of gradients doesn't exit properly.",
"The most possible reason is that the model is not compatible with ZeroDDP.")
self._setup_grads_ptr() self._setup_grads_ptr()
self._logger.debug( self._logger.debug(
f'comp cuda demand time: {self.gemini_manager._comp_cuda_demand_time}, layout time: {self.gemini_manager._layout_time}, evict time: {self.gemini_manager._evict_time}, CPU->CUDA vol: {self.gemini_manager._h2d_volume}B, CUDA->CPU vol: {self.gemini_manager._d2h_volume}' f'comp cuda demand time: {self.gemini_manager._comp_cuda_demand_time}, layout time: {self.gemini_manager._layout_time}, evict time: {self.gemini_manager._evict_time}, CPU->CUDA vol: {self.gemini_manager._h2d_volume}B, CUDA->CPU vol: {self.gemini_manager._d2h_volume}'
...@@ -274,12 +295,12 @@ class ZeroDDP(ColoDDP): ...@@ -274,12 +295,12 @@ class ZeroDDP(ColoDDP):
self.gemini_manager.post_iter() self.gemini_manager.post_iter()
def backward(self, loss: torch.Tensor): def backward(self, loss: torch.Tensor):
with self.param_op_hook.switch_to_backward(), ParamOpHookManager.use_hooks(self.param_op_hook): with self.param_op_hook.switch_to_backward(), ColoParamOpHookManager.use_hooks(self.param_op_hook):
loss.backward() loss.backward()
self._post_backward() self._post_backward()
def backward_by_grad(self, tensor, grad): def backward_by_grad(self, tensor, grad):
with self.param_op_hook.switch_to_backward(), ParamOpHookManager.use_hooks(self.param_op_hook): with self.param_op_hook.switch_to_backward(), ColoParamOpHookManager.use_hooks(self.param_op_hook):
torch.autograd.backward(tensor, grad) torch.autograd.backward(tensor, grad)
self._post_backward() self._post_backward()
...@@ -287,16 +308,21 @@ class ZeroDDP(ColoDDP): ...@@ -287,16 +308,21 @@ class ZeroDDP(ColoDDP):
empty_grad = torch.empty_like(grad) empty_grad = torch.empty_like(grad)
free_storage(empty_grad) free_storage(empty_grad)
with torch._C.DisableTorchFunction(): with torch._C.DisableTorchFunction():
self.chunk_manager.trans_tensor_state(p, TensorState.READY_FOR_REDUCE)
chunk = self.chunk_manager.get_chunk(p) chunk = self.chunk_manager.get_chunk(p)
assert chunk.tensors_info[p].state == TensorState.HOLD_AFTER_BWD
self.chunk_manager.trans_tensor_state(p, TensorState.READY_FOR_REDUCE)
chunk.copy_tensor_to_chunk_slice(p, grad) chunk.copy_tensor_to_chunk_slice(p, grad)
reduced = self.chunk_manager.reduce_chunk(chunk) reduced = self.chunk_manager.reduce_chunk(chunk)
if reduced: if reduced:
if chunk.is_gathered: if chunk.is_gathered:
chunk.chunk_total.div_(chunk.pg_size) chunk.cuda_global_chunk.div_(chunk.pg_size)
else: else:
chunk.cuda_shard.div_(chunk.pg_size) chunk.cuda_shard.div_(chunk.pg_size)
# check overflow elements
self.overflow_counter += chunk.has_inf_or_nan self.overflow_counter += chunk.has_inf_or_nan
# record l2 norm for gradient clipping
if chunk.l2_norm_flag:
chunk.set_l2_norm()
self.chunk_manager.move_chunk(chunk, self.grads_device[p], force_copy=True) self.chunk_manager.move_chunk(chunk, self.grads_device[p], force_copy=True)
return empty_grad return empty_grad
...@@ -307,12 +333,10 @@ class ZeroDDP(ColoDDP): ...@@ -307,12 +333,10 @@ class ZeroDDP(ColoDDP):
for tensor in chunk.get_tensors(): for tensor in chunk.get_tensors():
self.grads_device[tensor] = device self.grads_device[tensor] = device
def state_dict(self, destination=None, prefix='', keep_vars=False, only_rank_0: bool = True): def state_dict(self, destination=None, prefix='', keep_vars=False, only_rank_0: bool = True, strict: bool = True):
r"""Returns a dictionary containing a whole state of the module. """
Args:
Both parameters and persistent buffers (e.g. running averages) are strict (bool): whether to reture the whole model state as the pytorch `Module.state_dict()`
included. Keys are corresponding parameter and buffer names.
Parameters and buffers set to ``None`` are not included.
Returns: Returns:
dict: dict:
...@@ -322,7 +346,30 @@ class ZeroDDP(ColoDDP): ...@@ -322,7 +346,30 @@ class ZeroDDP(ColoDDP):
>>> module.state_dict().keys() >>> module.state_dict().keys()
['bias', 'weight'] ['bias', 'weight']
"""
if strict:
assert keep_vars is False, "`state_dict` with parameter, `keep_vars=True`, is not supported now."
torch_model = get_static_torch_model(zero_ddp_model=self, only_rank_0=only_rank_0)
return torch_model.state_dict(destination=destination, prefix=prefix, keep_vars=keep_vars)
return self._non_strict_state_dict(destination=destination,
prefix=prefix,
keep_vars=keep_vars,
only_rank_0=only_rank_0)
def _non_strict_state_dict(self, destination=None, prefix='', keep_vars=False, only_rank_0: bool = True):
"""Returns a dictionary containing a whole state of the module.
Both parameters and persistent buffers (e.g. running averages) are included.
Keys are corresponding parameter and buffer names.
Parameters and buffers set to ``None`` are not included.
Warning: The non strict state dict would ignore the parameters if the tensors of the parameters
are shared with other parameters which have been included in the dictionary.
When you need to load the state dict, you should set the argument `strict` to False.
Returns:
dict:
a dictionary containing a whole state of the module
""" """
if destination is None: if destination is None:
destination = OrderedDict() destination = OrderedDict()
...@@ -336,24 +383,20 @@ class ZeroDDP(ColoDDP): ...@@ -336,24 +383,20 @@ class ZeroDDP(ColoDDP):
destination = hook_result destination = hook_result
return destination return destination
def _save_to_state_dict(self, destination, prefix, keep_vars, only_rank_0=True): def _get_param_to_save_data(self, param_list: List[torch.nn.Parameter], only_rank_0: bool) -> Dict:
r"""Saves module state to `destination` dictionary, containing a state """
of the module, but not its descendants. This is called on every get param content from chunks.
submodule in :meth:`~torch.nn.Module.state_dict`.
In rare cases, subclasses can achieve class-specific behavior by
overriding this method with custom logic.
Args: Args:
destination (dict): a dict where state will be stored param_list (_type_): a list of torch.nn.Parameters
prefix (str): the prefix for parameters and buffers used in this only_rank_0 (_type_): _description_
module
"""
assert keep_vars is False, "`state_dict` with parameter, `keep_vars=True`, is not supported now."
Returns:
Dict: a dict whose key is param name and value is param with correct payload
"""
# save parameters # save parameters
param_to_save_data = dict() param_to_save_data = dict()
chunk_list = self.chunk_manager.get_chunks(self.fp32_params) chunk_list = self.chunk_manager.get_chunks(param_list)
for chunk in chunk_list: for chunk in chunk_list:
temp_chunk = get_temp_total_chunk_on_cuda(chunk) temp_chunk = get_temp_total_chunk_on_cuda(chunk)
...@@ -367,7 +410,25 @@ class ZeroDDP(ColoDDP): ...@@ -367,7 +410,25 @@ class ZeroDDP(ColoDDP):
param_to_save_data[tensor] = record_tensor param_to_save_data[tensor] = record_tensor
del temp_chunk del temp_chunk
return param_to_save_data
def _save_to_state_dict(self, destination, prefix, keep_vars, only_rank_0=True):
r"""Saves module state to `destination` dictionary, containing a state
of the module, but not its descendants. This is called on every
submodule in :meth:`~torch.nn.Module.state_dict`.
In rare cases, subclasses can achieve class-specific behavior by
overriding this method with custom logic.
Args:
destination (dict): a dict where state will be stored
prefix (str): the prefix for parameters and buffers used in this
module
"""
assert keep_vars is False, "`state_dict` with parameter, `keep_vars=True`, is not supported now."
param_to_save_data = self._get_param_to_save_data(self.fp32_params, only_rank_0)
# TODO: (HELSON) deal with ddp ignored parameters
for (name, p), fp32_p in zip(self.named_parameters(), self.fp32_params): for (name, p), fp32_p in zip(self.named_parameters(), self.fp32_params):
if p is not None: if p is not None:
assert fp32_p in param_to_save_data, "Parameter '{}' is neglected in the chunk list".format(name) assert fp32_p in param_to_save_data, "Parameter '{}' is neglected in the chunk list".format(name)
...@@ -519,7 +580,7 @@ class ZeroDDP(ColoDDP): ...@@ -519,7 +580,7 @@ class ZeroDDP(ColoDDP):
load(parameter_name, tensor, partial(load_fp32_parameter, parameter_slice)) load(parameter_name, tensor, partial(load_fp32_parameter, parameter_slice))
if chunk.is_gathered: if chunk.is_gathered:
chunk.chunk_total.copy_(temp_chunk) chunk.cuda_global_chunk.copy_(temp_chunk)
elif chunk.cuda_shard is not None: elif chunk.cuda_shard is not None:
chunk.cuda_shard.copy_(temp_chunk[chunk.shard_begin:chunk.shard_end]) chunk.cuda_shard.copy_(temp_chunk[chunk.shard_begin:chunk.shard_end])
else: else:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment