Commit 653b0a62 authored by zbian, committed by アマデウス

added skip_bias_add for non-tp linear

parent e5b1a0c9
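In short: `colossalai.nn.Linear` now routes its non-tensor-parallel case through the same `_parallel_linear` lookup table as the parallel cases (the key `None` maps to a new `VanillaLinear` layer), and `VanillaLinear` accepts a `skip_bias_add` flag. With the flag set, `forward` returns the matmul output and the bias separately, typically so a downstream fused kernel can apply the bias together with an activation.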
-import math
+import inspect
+import math
 from typing import Callable
-from colossalai.utils import get_current_device
 from torch import dtype, nn
+from colossalai.utils import get_current_device
 from ... import init as init
 from ..parallel_1d import *
 from ..parallel_2d import *
@@ -14,7 +15,7 @@ from ..utils import get_tensor_parallel_mode
 from ..vanilla import *
 from ._utils import ColossalaiModule

-_parallel_linear = {'1d': Linear1D, '2d': Linear2D, '2.5d': Linear2p5D, '3d': Linear3D}
+_parallel_linear = {None: VanillaLinear, '1d': Linear1D, '2d': Linear2D, '2.5d': Linear2p5D, '3d': Linear3D}

 _parallel_classifier = {
     None: VanillaClassifier,
@@ -73,16 +74,9 @@ class Linear(ColossalaiModule):
                  bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1),
                  **kwargs) -> None:
         tensor_parallel = get_tensor_parallel_mode()
-        if tensor_parallel is None:
-            layer = nn.Linear(in_features, out_features, bias=bias).to(dtype).to(get_current_device())
-            weight_initializer(layer.weight, fan_in=in_features, fan_out=out_features)
-            if layer.bias is not None:
-                bias_initializer(layer.bias, fan_in=in_features)
-        else:
-            linear_cls = _parallel_linear[tensor_parallel]
-            gather_output = kwargs.pop('gather_output', None)
-            if 'gather_output' in inspect.signature(
-                    linear_cls.__init__).parameters.keys():  # gather_output arg is available
-                kwargs['gather_output'] = gather_output
-            layer = linear_cls(
-                in_features,
+        linear_cls = _parallel_linear[tensor_parallel]
+        gather_output = kwargs.pop('gather_output', None)
+        if 'gather_output' in inspect.signature(linear_cls.__init__).parameters.keys():  # gather_output arg is available
+            kwargs['gather_output'] = gather_output
+        layer = linear_cls(
+            in_features,
...
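With the dispatch unified, keyword arguments flow through `**kwargs` to whichever class is selected, so `skip_bias_add` now reaches `VanillaLinear` when no tensor parallel mode is configured. A minimal usage sketch, assuming an initialized ColossalAI context where `get_tensor_parallel_mode()` returns `None` (the sizes and tensors here are illustrative, not from this commit):

```python
import torch
from colossalai import nn as col_nn
from colossalai.utils import get_current_device

# No tensor parallel mode configured: Linear resolves to VanillaLinear,
# and skip_bias_add is forwarded through **kwargs.
layer = col_nn.Linear(512, 1024, skip_bias_add=True)

x = torch.randn(4, 512, device=get_current_device())
out, bias = layer(x)  # bias is returned separately instead of being added
```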
-from .layers import (DropPath, VanillaClassifier, VanillaLayerNorm, VanillaPatchEmbedding, WrappedDropout,
-                     WrappedDropPath)
+from .layers import (
+    DropPath,
+    VanillaClassifier,
+    VanillaLayerNorm,
+    VanillaLinear,
+    VanillaPatchEmbedding,
+    WrappedDropout,
+    WrappedDropPath,
+)

 __all__ = [
-    "VanillaLayerNorm", "VanillaPatchEmbedding", "VanillaClassifier", "DropPath", "WrappedDropout", "WrappedDropPath"
+    "VanillaLayerNorm", "VanillaPatchEmbedding", "VanillaClassifier", "DropPath", "WrappedDropout", "WrappedDropPath",
+    "VanillaLinear"
 ]
@@ -3,12 +3,14 @@ from typing import Callable

 import torch
 import torch.nn.functional as F
+from torch import Tensor
+from torch import nn as nn
+from torch.nn.parameter import Parameter

 from colossalai.context import seed
 from colossalai.nn import init as init
 from colossalai.registry import LAYERS
 from colossalai.utils.cuda import get_current_device
-from torch import Tensor
-from torch import nn as nn

 from ..utils import to_2tuple
@@ -288,3 +290,52 @@ class VanillaLayerNorm(nn.Module):

     def forward(self, x: Tensor) -> Tensor:
         return F.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.variance_epsilon)
+
+
+@LAYERS.register_module
+class VanillaLinear(nn.Module):
+    """Linear layer.
+
+    Args:
+        in_features (int): size of each input sample.
+        out_features (int): size of each output sample.
+        bias (bool, optional): If set to ``False``, the layer will not learn an additive bias, defaults to ``True``.
+        dtype (:class:`torch.dtype`, optional): The dtype of parameters, defaults to None.
+        skip_bias_add (bool, optional): If set to ``True``, the bias is not added in ``forward`` but returned
+            alongside the output so the caller can apply it later, defaults to ``False``.
+        weight_initializer (:class:`typing.Callable`, optional):
+            The initializer of weight, defaults to kaiming uniform initializer.
+        bias_initializer (:class:`typing.Callable`, optional):
+            The initializer of bias, defaults to xavier uniform initializer.
+
+    For more details about ``initializer``, please refer to
+    `init <https://github.com/hpcaitech/ColossalAI/blob/main/colossalai/nn/init.py>`_.
+    """
+
+    def __init__(self,
+                 in_features: int,
+                 out_features: int,
+                 bias: bool = True,
+                 dtype: torch.dtype = None,
+                 skip_bias_add: bool = False,
+                 weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)),
+                 bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1),
+                 **kwargs) -> None:
+        super().__init__()
+        self.in_features = in_features
+        self.out_features = out_features
+        self.skip_bias_add = skip_bias_add
+        factory_kwargs = {'device': get_current_device(), 'dtype': dtype}
+        self.weight = Parameter(torch.empty(self.out_features, self.in_features, **factory_kwargs))
+        if bias:
+            self.bias = Parameter(torch.empty(self.out_features, **factory_kwargs))
+        else:
+            self.bias = None
+        weight_initializer(self.weight, fan_in=in_features, fan_out=out_features)
+        if self.bias is not None:
+            bias_initializer(self.bias, fan_in=in_features)
+
+    def forward(self, input: Tensor) -> Tensor:
+        if not self.skip_bias_add:
+            return F.linear(input, self.weight, self.bias)
+        else:
+            # bias is deferred: return it with the output for the caller to apply later
+            return F.linear(input, self.weight), self.bias
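The point of `skip_bias_add` is to let the caller apply the bias later, ideally in a single fused elementwise pass together with an activation. A plain-PyTorch sketch of that pattern follows; `fused_bias_gelu` is a hypothetical stand-in for a real fused kernel and is not part of this commit:

```python
import torch
import torch.nn.functional as F

def fused_bias_gelu(x: torch.Tensor, bias: torch.Tensor) -> torch.Tensor:
    # stand-in for a fused bias+activation kernel: bias add and GELU in one step
    return F.gelu(x + bias)

weight = torch.randn(1024, 512)
bias = torch.randn(1024)
x = torch.randn(4, 512)

# What VanillaLinear(skip_bias_add=True) produces: the matmul output plus the raw bias
out = F.linear(x, weight)

# Same result as F.gelu(F.linear(x, weight, bias)), but the bias add is deferred
y = fused_bias_gelu(out, bias)
```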