Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
ColossalAI
Commits
653b0a62
Commit
653b0a62
authored
Nov 09, 2022
by
zbian
Committed by
アマデウス
Nov 09, 2022
Browse files
added skip_bias_add for non-tp linear
parent
e5b1a0c9
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
493 additions
and
440 deletions
+493
-440
colossalai/nn/layer/colossalai_layer/linear.py
colossalai/nn/layer/colossalai_layer/linear.py
+141
-147
colossalai/nn/layer/vanilla/__init__.py
colossalai/nn/layer/vanilla/__init__.py
+11
-3
colossalai/nn/layer/vanilla/layers.py
colossalai/nn/layer/vanilla/layers.py
+341
-290
No files found.
colossalai/nn/layer/colossalai_layer/linear.py
View file @
653b0a62
import
math
import
inspect
import
inspect
import
math
from
typing
import
Callable
from
typing
import
Callable
from
colossalai.utils
import
get_current_device
from
torch
import
dtype
,
nn
from
torch
import
dtype
,
nn
from
colossalai.utils
import
get_current_device
from
...
import
init
as
init
from
..parallel_1d
import
*
from
...
import
init
as
init
from
..parallel_2d
import
*
from
..parallel_1d
import
*
from
..parallel_2p5d
import
*
from
..parallel_2d
import
*
from
..parallel_3d
import
*
from
..parallel_2p5d
import
*
from
..utils
import
get_tensor_parallel_mode
from
..parallel_3d
import
*
from
..vanilla
import
*
from
..utils
import
get_tensor_parallel_mode
from
._utils
import
ColossalaiModule
from
..vanilla
import
*
from
._utils
import
ColossalaiModule
_parallel_linear
=
{
'1d'
:
Linear1D
,
'2d'
:
Linear2D
,
'2.5d'
:
Linear2p5D
,
'3d'
:
Linear3D
}
_parallel_linear
=
{
None
:
VanillaLinear
,
'1d'
:
Linear1D
,
'2d'
:
Linear2D
,
'2.5d'
:
Linear2p5D
,
'3d'
:
Linear3D
}
_parallel_classifier
=
{
None
:
VanillaClassifier
,
_parallel_classifier
=
{
'1d'
:
Classifier1D
,
None
:
VanillaClassifier
,
'2d'
:
Classifier2D
,
'1d'
:
Classifier1D
,
'2.5d'
:
Classifier2p5D
,
'2d'
:
Classifier2D
,
'3d'
:
Classifier3D
'2.5d'
:
Classifier2p5D
,
}
'3d'
:
Classifier3D
}
_vocab_parallel_classifier
=
{
'1d'
:
VocabParallelClassifier1D
,
_vocab_parallel_classifier
=
{
'2d'
:
VocabParallelClassifier2D
,
'1d'
:
VocabParallelClassifier1D
,
'2.5d'
:
VocabParallelClassifier2p5D
,
'2d'
:
VocabParallelClassifier2D
,
'3d'
:
VocabParallelClassifier3D
'2.5d'
:
VocabParallelClassifier2p5D
,
}
'3d'
:
VocabParallelClassifier3D
}
class
Linear
(
ColossalaiModule
):
"""Linear layer of colossalai.
class
Linear
(
ColossalaiModule
):
"""Linear layer of colossalai.
Args:
in_features (int): size of each input sample.
Args:
out_features (int): size of each output sample.
in_features (int): size of each input sample.
bias (bool, optional): If set to ``False``, the layer will not learn an additive bias, defaults to ``True``.
out_features (int): size of each output sample.
dtype (:class:`torch.dtype`, optional): The dtype of parameters, defaults to None.
bias (bool, optional): If set to ``False``, the layer will not learn an additive bias, defaults to ``True``.
weight_initializer (:class:`typing.Callable`, optional):
dtype (:class:`torch.dtype`, optional): The dtype of parameters, defaults to None.
The initializer of weight, defaults to kaiming uniform initializer.
weight_initializer (:class:`typing.Callable`, optional):
bias_initializer (:class:`typing.Callable`, optional):
The initializer of weight, defaults to kaiming uniform initializer.
The initializer of bias, defaults to xavier uniform initializer.
bias_initializer (:class:`typing.Callable`, optional):
The initializer of bias, defaults to xavier uniform initializer.
Note: ``kwargs`` would contain different parameters when you use different parallelisms.
Note: ``kwargs`` would contain different parameters when you use different parallelisms.
The ``kwargs`` should contain parameters below:
::
The ``kwargs`` should contain parameters below:
::
Linear1D:
gather_output: bool (optional, default to be false)
Linear1D:
skip_bias_add: bool (optional, default to be false)
gather_output: bool (optional, default to be false)
Linear2D:
skip_bias_add: bool (optional, default to be false)
skip_bias_add: bool (optional, default to be false)
Linear2D:
Linear2p5D:
skip_bias_add: bool (optional, default to be false)
skip_bias_add: bool (optional, default to be false)
Linear2p5D:
Linear3D:
skip_bias_add: bool (optional, default to be false)
None
Linear3D:
None
More details about ``initializer`` please refer to
`init <https://github.com/hpcaitech/ColossalAI/blob/main/colossalai/nn/init.py>`_.
More details about ``initializer`` please refer to
"""
`init <https://github.com/hpcaitech/ColossalAI/blob/main/colossalai/nn/init.py>`_.
"""
def
__init__
(
self
,
in_features
:
int
,
def
__init__
(
self
,
out_features
:
int
,
in_features
:
int
,
bias
:
bool
=
True
,
out_features
:
int
,
dtype
:
dtype
=
None
,
bias
:
bool
=
True
,
weight_initializer
:
Callable
=
init
.
kaiming_uniform_
(
a
=
math
.
sqrt
(
5
)),
dtype
:
dtype
=
None
,
bias_initializer
:
Callable
=
init
.
xavier_uniform_
(
a
=
1
,
scale
=
1
),
weight_initializer
:
Callable
=
init
.
kaiming_uniform_
(
a
=
math
.
sqrt
(
5
)),
**
kwargs
)
->
None
:
bias_initializer
:
Callable
=
init
.
xavier_uniform_
(
a
=
1
,
scale
=
1
),
tensor_parallel
=
get_tensor_parallel_mode
()
**
kwargs
)
->
None
:
if
tensor_parallel
is
None
:
tensor_parallel
=
get_tensor_parallel_mode
()
layer
=
nn
.
Linear
(
in_features
,
out_features
,
bias
=
bias
).
to
(
dtype
).
to
(
get_current_device
())
linear_cls
=
_parallel_linear
[
tensor_parallel
]
weight_initializer
(
layer
.
weight
,
fan_in
=
in_features
,
fan_out
=
out_features
)
gather_output
=
kwargs
.
pop
(
'gather_output'
,
None
)
if
layer
.
bias
is
not
None
:
if
'gather_output'
in
inspect
.
signature
(
linear_cls
.
__init__
).
parameters
.
keys
():
# gather_out arg is available
bias_initializer
(
layer
.
bias
,
fan_in
=
in_features
)
kwargs
[
'gather_output'
]
=
gather_output
else
:
layer
=
linear_cls
(
linear_cls
=
_parallel_linear
[
tensor_parallel
]
in_features
,
gather_output
=
kwargs
.
pop
(
'gather_output'
,
None
)
out_features
,
if
'gather_output'
in
inspect
.
signature
(
bias
=
bias
,
linear_cls
.
__init__
).
parameters
.
keys
():
# gather_out arg is available
dtype
=
dtype
,
kwargs
[
'gather_output'
]
=
gather_output
weight_initializer
=
weight_initializer
,
layer
=
linear_cls
(
bias_initializer
=
bias_initializer
,
in_features
,
**
kwargs
,
out_features
,
)
bias
=
bias
,
super
().
__init__
(
layer
)
dtype
=
dtype
,
weight_initializer
=
weight_initializer
,
bias_initializer
=
bias_initializer
,
class
Classifier
(
ColossalaiModule
):
**
kwargs
,
"""Classifier layer of colossalai.
)
super
().
__init__
(
layer
)
Args:
in_features (int): size of each input sample.
num_classes (int): number of classes.
class
Classifier
(
ColossalaiModule
):
weight (:class:`torch.nn.Parameter`, optional): weight of the classifier, defaults to None.
"""Classifier layer of colossalai.
bias (bool, optional): If set to ``False``, the layer will not learn an additive bias, defaults to ``True``.
dtype (:class:`torch.dtype`, optional): The dtype of parameters, defaults to None.
Args:
weight_initializer (:class:`typing.Callable`, optional):
in_features (int): size of each input sample.
The initializer of weight, defaults to kaiming uniform initializer.
num_classes (int): number of classes.
bias_initializer (:class:`typing.Callable`, optional):
weight (:class:`torch.nn.Parameter`, optional): weight of the classifier, defaults to None.
The initializer of bias, defaults to xavier uniform initializer.
bias (bool, optional): If set to ``False``, the layer will not learn an additive bias, defaults to ``True``.
dtype (:class:`torch.dtype`, optional): The dtype of parameters, defaults to None.
More details about ``initializer`` please refer to
weight_initializer (:class:`typing.Callable`, optional):
`init <https://github.com/hpcaitech/ColossalAI/blob/main/colossalai/nn/init.py>`_.
The initializer of weight, defaults to kaiming uniform initializer.
"""
bias_initializer (:class:`typing.Callable`, optional):
The initializer of bias, defaults to xavier uniform initializer.
def
__init__
(
self
,
in_features
:
int
,
More details about ``initializer`` please refer to
num_classes
:
int
,
`init <https://github.com/hpcaitech/ColossalAI/blob/main/colossalai/nn/init.py>`_.
weight
:
nn
.
Parameter
=
None
,
"""
bias
:
bool
=
True
,
dtype
:
dtype
=
None
,
def
__init__
(
self
,
weight_initializer
:
Callable
=
init
.
kaiming_uniform_
(
a
=
math
.
sqrt
(
5
)),
in_features
:
int
,
bias_initializer
:
Callable
=
init
.
xavier_uniform_
(
a
=
1
,
scale
=
1
),
num_classes
:
int
,
vocab_parallel_limit
:
int
=
2048
)
->
None
:
weight
:
nn
.
Parameter
=
None
,
tensor_parallel
=
get_tensor_parallel_mode
()
bias
:
bool
=
True
,
if
num_classes
<=
vocab_parallel_limit
or
tensor_parallel
is
None
:
dtype
:
dtype
=
None
,
layer
=
_parallel_classifier
[
tensor_parallel
](
weight_initializer
:
Callable
=
init
.
kaiming_uniform_
(
a
=
math
.
sqrt
(
5
)),
in_features
,
bias_initializer
:
Callable
=
init
.
xavier_uniform_
(
a
=
1
,
scale
=
1
),
num_classes
,
vocab_parallel_limit
:
int
=
2048
)
->
None
:
weight
=
weight
,
tensor_parallel
=
get_tensor_parallel_mode
()
bias
=
bias
,
if
num_classes
<=
vocab_parallel_limit
or
tensor_parallel
is
None
:
dtype
=
dtype
,
layer
=
_parallel_classifier
[
tensor_parallel
](
weight_initializer
=
weight_initializer
,
in_features
,
bias_initializer
=
bias_initializer
,
num_classes
,
)
weight
=
weight
,
else
:
bias
=
bias
,
layer
=
_vocab_parallel_classifier
[
tensor_parallel
](
dtype
=
dtype
,
in_features
,
weight_initializer
=
weight_initializer
,
num_classes
,
bias_initializer
=
bias_initializer
,
weight
=
weight
,
)
bias
=
bias
,
else
:
dtype
=
dtype
,
layer
=
_vocab_parallel_classifier
[
tensor_parallel
](
weight_initializer
=
weight_initializer
,
in_features
,
bias_initializer
=
bias_initializer
,
num_classes
,
)
weight
=
weight
,
super
().
__init__
(
layer
)
bias
=
bias
,
dtype
=
dtype
,
weight_initializer
=
weight_initializer
,
bias_initializer
=
bias_initializer
,
)
super
().
__init__
(
layer
)
colossalai/nn/layer/vanilla/__init__.py
View file @
653b0a62
from
.layers
import
(
DropPath
,
VanillaClassifier
,
VanillaLayerNorm
,
VanillaPatchEmbedding
,
WrappedDropout
,
from
.layers
import
(
WrappedDropPath
)
DropPath
,
VanillaClassifier
,
VanillaLayerNorm
,
VanillaLinear
,
VanillaPatchEmbedding
,
WrappedDropout
,
WrappedDropPath
,
)
__all__
=
[
__all__
=
[
"VanillaLayerNorm"
,
"VanillaPatchEmbedding"
,
"VanillaClassifier"
,
"DropPath"
,
"WrappedDropout"
,
"WrappedDropPath"
"VanillaLayerNorm"
,
"VanillaPatchEmbedding"
,
"VanillaClassifier"
,
"DropPath"
,
"WrappedDropout"
,
"WrappedDropPath"
,
"VanillaLinear"
]
]
colossalai/nn/layer/vanilla/layers.py
View file @
653b0a62
import
math
import
math
from
typing
import
Callable
from
typing
import
Callable
import
torch
import
torch
import
torch.nn.functional
as
F
import
torch.nn.functional
as
F
from
colossalai.context
import
seed
from
torch
import
Tensor
from
colossalai.nn
import
init
as
init
from
torch
import
nn
as
nn
from
colossalai.registry
import
LAYERS
from
torch.nn.parameter
import
Parameter
from
colossalai.utils.cuda
import
get_current_device
from
torch
import
Tensor
from
colossalai.context
import
seed
from
torch
import
nn
as
nn
from
colossalai.nn
import
init
as
init
from
colossalai.registry
import
LAYERS
from
..utils
import
to_2tuple
from
colossalai.utils.cuda
import
get_current_device
from
..utils
import
to_2tuple
def
drop_path
(
x
,
drop_prob
:
float
=
0.
,
training
:
bool
=
False
):
"""Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
def
drop_path
(
x
,
drop_prob
:
float
=
0.
,
training
:
bool
=
False
):
This is the same as the DropConnect impl I created for EfficientNet, etc networks, however,
"""Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper...
See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for
This is the same as the DropConnect impl I created for EfficientNet, etc networks, however,
changing the layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use
the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper...
'survival rate' as the argument.
See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for
changing the layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use
Args:
'survival rate' as the argument.
drop_prob (float, optional): probability of dropping path, defaults 0.0.
training (bool, optional): whether in training progress, defaults False.
Args:
"""
drop_prob (float, optional): probability of dropping path, defaults 0.0.
if
drop_prob
==
0.
or
not
training
:
training (bool, optional): whether in training progress, defaults False.
return
x
"""
keep_prob
=
1
-
drop_prob
if
drop_prob
==
0.
or
not
training
:
shape
=
(
x
.
shape
[
0
],)
+
(
1
,)
*
(
x
.
ndim
-
1
)
# work with diff dim tensors, not just 2D ConvNets
return
x
random_tensor
=
keep_prob
+
torch
.
rand
(
shape
,
dtype
=
x
.
dtype
,
device
=
x
.
device
)
keep_prob
=
1
-
drop_prob
random_tensor
.
floor_
()
# binarize
shape
=
(
x
.
shape
[
0
],)
+
(
1
,)
*
(
x
.
ndim
-
1
)
# work with diff dim tensors, not just 2D ConvNets
output
=
x
.
div
(
keep_prob
)
*
random_tensor
random_tensor
=
keep_prob
+
torch
.
rand
(
shape
,
dtype
=
x
.
dtype
,
device
=
x
.
device
)
return
output
random_tensor
.
floor_
()
# binarize
output
=
x
.
div
(
keep_prob
)
*
random_tensor
return
output
class
DropPath
(
nn
.
Module
):
"""
Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
class
DropPath
(
nn
.
Module
):
Adapted from https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/layers/drop.py
"""
Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
Args:
Adapted from https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/layers/drop.py
drop_prob (float, optional): probability of dropping path, defaults None.
"""
Args:
drop_prob (float, optional): probability of dropping path, defaults None.
def
__init__
(
self
,
drop_prob
=
None
):
"""
super
(
DropPath
,
self
).
__init__
()
self
.
drop_prob
=
drop_prob
def
__init__
(
self
,
drop_prob
=
None
):
super
(
DropPath
,
self
).
__init__
()
def
forward
(
self
,
x
):
self
.
drop_prob
=
drop_prob
return
drop_path
(
x
,
self
.
drop_prob
,
self
.
training
)
def
forward
(
self
,
x
):
return
drop_path
(
x
,
self
.
drop_prob
,
self
.
training
)
class
WrappedDropout
(
nn
.
Module
):
r
"""Same as torch.nn.Dropout. But it is wrapped with the context of seed manager. During training, randomly zeroes
some elements of the input tensor with probability p using samples from a Bernoulli distribution. Each
class
WrappedDropout
(
nn
.
Module
):
channel will be zeroed out independently on every forward call. Furthermore, the outputs are scaled by a factor of
r
"""Same as torch.nn.Dropout. But it is wrapped with the context of seed manager. During training, randomly zeroes
1/(1-p) during training. This means that during evaluation the module simply computes an identity function.
some elements of the input tensor with probability p using samples from a Bernoulli distribution. Each
channel will be zeroed out independently on every forward call. Furthermore, the outputs are scaled by a factor of
Args:
1/(1-p) during training. This means that during evaluation the module simply computes an identity function.
p (float, optional): probability of an element to be zeroed, defaults 0.5.
inplace (bool, optional): whether to do dropout in-place, default to be False.
Args:
mode (:class:`colossalai.context.ParallelMode`): The chosen parallel mode.
p (float, optional): probability of an element to be zeroed, defaults 0.5.
inplace (bool, optional): whether to do dropout in-place, default to be False.
Note:
mode (:class:`colossalai.context.ParallelMode`): The chosen parallel mode.
The parallel_mode should be concluded in ``ParallelMode``. More details about ``ParallelMode`` could be found
in `parallel_mode <https://github.com/hpcaitech/ColossalAI/blob/main/colossalai/context/parallel_mode.py>`_
Note:
"""
The parallel_mode should be concluded in ``ParallelMode``. More details about ``ParallelMode`` could be found
in `parallel_mode <https://github.com/hpcaitech/ColossalAI/blob/main/colossalai/context/parallel_mode.py>`_
def
__init__
(
self
,
p
:
float
=
0.5
,
inplace
:
bool
=
False
,
mode
=
None
):
"""
super
().
__init__
()
if
p
<
0
or
p
>
1
:
def
__init__
(
self
,
p
:
float
=
0.5
,
inplace
:
bool
=
False
,
mode
=
None
):
raise
ValueError
(
"dropout probability has to be between 0 and 1, "
super
().
__init__
()
"but got {}"
.
format
(
p
))
if
p
<
0
or
p
>
1
:
self
.
p
=
p
raise
ValueError
(
"dropout probability has to be between 0 and 1, "
self
.
inplace
=
inplace
"but got {}"
.
format
(
p
))
if
mode
is
None
:
self
.
p
=
p
self
.
func
=
self
.
nonefunc
self
.
inplace
=
inplace
else
:
if
mode
is
None
:
self
.
func
=
self
.
normalfunc
self
.
func
=
self
.
nonefunc
self
.
mode
=
mode
else
:
self
.
func
=
self
.
normalfunc
def
nonefunc
(
self
,
inputs
):
self
.
mode
=
mode
return
F
.
dropout
(
inputs
,
self
.
p
,
self
.
training
,
self
.
inplace
)
def
nonefunc
(
self
,
inputs
):
def
normalfunc
(
self
,
inputs
):
return
F
.
dropout
(
inputs
,
self
.
p
,
self
.
training
,
self
.
inplace
)
with
seed
(
self
.
mode
):
return
F
.
dropout
(
inputs
,
self
.
p
,
self
.
training
,
self
.
inplace
)
def
normalfunc
(
self
,
inputs
):
with
seed
(
self
.
mode
):
def
forward
(
self
,
inputs
):
return
F
.
dropout
(
inputs
,
self
.
p
,
self
.
training
,
self
.
inplace
)
return
self
.
func
(
inputs
)
def
forward
(
self
,
inputs
):
return
self
.
func
(
inputs
)
class
WrappedDropPath
(
nn
.
Module
):
r
"""Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
Here, it is wrapped with the context of seed manager.
class
WrappedDropPath
(
nn
.
Module
):
r
"""Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
Args:
Here, it is wrapped with the context of seed manager.
p (float, optional): probability of dropping path, defaults 0.0.
mode (:class:`colossalai.context.ParallelMode`): The chosen parallel mode.
Args:
p (float, optional): probability of dropping path, defaults 0.0.
Note:
mode (:class:`colossalai.context.ParallelMode`): The chosen parallel mode.
The parallel_mode should be concluded in ``ParallelMode``. More details about ``ParallelMode`` could be found
in `parallel_mode <https://github.com/hpcaitech/ColossalAI/blob/main/colossalai/context/parallel_mode.py>`_
Note:
"""
The parallel_mode should be concluded in ``ParallelMode``. More details about ``ParallelMode`` could be found
in `parallel_mode <https://github.com/hpcaitech/ColossalAI/blob/main/colossalai/context/parallel_mode.py>`_
def
__init__
(
self
,
p
:
float
=
0.
,
mode
=
None
):
"""
super
().
__init__
()
self
.
p
=
p
def
__init__
(
self
,
p
:
float
=
0.
,
mode
=
None
):
self
.
mode
=
mode
super
().
__init__
()
if
self
.
mode
is
None
:
self
.
p
=
p
self
.
func
=
self
.
nonefunc
self
.
mode
=
mode
else
:
if
self
.
mode
is
None
:
self
.
func
=
self
.
normalfunc
self
.
func
=
self
.
nonefunc
self
.
mode
=
mode
else
:
self
.
func
=
self
.
normalfunc
def
nonefunc
(
self
,
inputs
):
self
.
mode
=
mode
return
drop_path
(
inputs
,
self
.
p
,
self
.
training
)
def
nonefunc
(
self
,
inputs
):
def
normalfunc
(
self
,
inputs
):
return
drop_path
(
inputs
,
self
.
p
,
self
.
training
)
with
seed
(
self
.
mode
):
return
drop_path
(
inputs
,
self
.
p
,
self
.
training
)
def
normalfunc
(
self
,
inputs
):
with
seed
(
self
.
mode
):
def
forward
(
self
,
inputs
):
return
drop_path
(
inputs
,
self
.
p
,
self
.
training
)
return
self
.
func
(
inputs
)
def
forward
(
self
,
inputs
):
return
self
.
func
(
inputs
)
@
LAYERS
.
register_module
class
VanillaPatchEmbedding
(
nn
.
Module
):
r
"""
@
LAYERS
.
register_module
2D Image to Patch Embedding
class
VanillaPatchEmbedding
(
nn
.
Module
):
r
"""
Args:
2D Image to Patch Embedding
img_size (int): image size.
patch_size (int): patch size.
Args:
in_chans (int): number of channels of input image.
img_size (int): image size.
embed_size (int): size of embedding.
patch_size (int): patch size.
dtype (:class:`torch.dtype`, optional): The dtype of parameters, defaults to None.
in_chans (int): number of channels of input image.
flatten (bool, optional): whether to flatten output tensor, defaults to True.
embed_size (int): size of embedding.
weight_initializer (:class:`typing.Callable`, optional):
dtype (:class:`torch.dtype`, optional): The dtype of parameters, defaults to None.
The initializer of weight, defaults to kaiming uniform initializer.
flatten (bool, optional): whether to flatten output tensor, defaults to True.
bias_initializer (:class:`typing.Callable`, optional):
weight_initializer (:class:`typing.Callable`, optional):
The initializer of bias, defaults to xavier uniform initializer.
The initializer of weight, defaults to kaiming uniform initializer.
position_embed_initializer (:class:`typing.Callable`, optional):
bias_initializer (:class:`typing.Callable`, optional):
The initializer of position embedding, defaults to zeros initializer.
The initializer of bias, defaults to xavier uniform initializer.
position_embed_initializer (:class:`typing.Callable`, optional):
More details about initializer please refer to
The initializer of position embedding, defaults to zeros initializer.
`init <https://github.com/hpcaitech/ColossalAI/blob/main/colossalai/nn/init.py>`_.
"""
More details about initializer please refer to
`init <https://github.com/hpcaitech/ColossalAI/blob/main/colossalai/nn/init.py>`_.
def
__init__
(
self
,
"""
img_size
:
int
,
patch_size
:
int
,
def
__init__
(
self
,
in_chans
:
int
,
img_size
:
int
,
embed_size
:
int
,
patch_size
:
int
,
flatten
:
bool
=
True
,
in_chans
:
int
,
dtype
:
torch
.
dtype
=
None
,
embed_size
:
int
,
weight_initializer
:
Callable
=
init
.
kaiming_uniform_
(
a
=
math
.
sqrt
(
5
)),
flatten
:
bool
=
True
,
bias_initializer
:
Callable
=
init
.
xavier_uniform_
(
a
=
1
,
scale
=
1
),
dtype
:
torch
.
dtype
=
None
,
position_embed_initializer
:
Callable
=
init
.
zeros_
()):
weight_initializer
:
Callable
=
init
.
kaiming_uniform_
(
a
=
math
.
sqrt
(
5
)),
super
().
__init__
()
bias_initializer
:
Callable
=
init
.
xavier_uniform_
(
a
=
1
,
scale
=
1
),
img_size
=
to_2tuple
(
img_size
)
position_embed_initializer
:
Callable
=
init
.
zeros_
()):
patch_size
=
to_2tuple
(
patch_size
)
super
().
__init__
()
self
.
img_size
=
img_size
img_size
=
to_2tuple
(
img_size
)
self
.
patch_size
=
patch_size
patch_size
=
to_2tuple
(
patch_size
)
self
.
grid_size
=
(
img_size
[
0
]
//
patch_size
[
0
],
img_size
[
1
]
//
patch_size
[
1
])
self
.
img_size
=
img_size
self
.
num_patches
=
self
.
grid_size
[
0
]
*
self
.
grid_size
[
1
]
self
.
patch_size
=
patch_size
self
.
flatten
=
flatten
self
.
grid_size
=
(
img_size
[
0
]
//
patch_size
[
0
],
img_size
[
1
]
//
patch_size
[
1
])
self
.
num_patches
=
self
.
grid_size
[
0
]
*
self
.
grid_size
[
1
]
self
.
weight
=
nn
.
Parameter
(
self
.
flatten
=
flatten
torch
.
empty
((
embed_size
,
in_chans
,
*
self
.
patch_size
),
device
=
get_current_device
(),
dtype
=
dtype
))
self
.
bias
=
nn
.
Parameter
(
torch
.
empty
(
embed_size
,
device
=
get_current_device
(),
dtype
=
dtype
))
self
.
weight
=
nn
.
Parameter
(
self
.
cls_token
=
nn
.
Parameter
(
torch
.
zeros
((
1
,
1
,
embed_size
),
device
=
get_current_device
(),
dtype
=
dtype
))
torch
.
empty
((
embed_size
,
in_chans
,
*
self
.
patch_size
),
device
=
get_current_device
(),
dtype
=
dtype
))
self
.
pos_embed
=
nn
.
Parameter
(
self
.
bias
=
nn
.
Parameter
(
torch
.
empty
(
embed_size
,
device
=
get_current_device
(),
dtype
=
dtype
))
torch
.
zeros
((
1
,
self
.
num_patches
+
1
,
embed_size
),
device
=
get_current_device
(),
dtype
=
dtype
))
self
.
cls_token
=
nn
.
Parameter
(
torch
.
zeros
((
1
,
1
,
embed_size
),
device
=
get_current_device
(),
dtype
=
dtype
))
self
.
pos_embed
=
nn
.
Parameter
(
self
.
reset_parameters
(
weight_initializer
,
bias_initializer
,
position_embed_initializer
)
torch
.
zeros
((
1
,
self
.
num_patches
+
1
,
embed_size
),
device
=
get_current_device
(),
dtype
=
dtype
))
def
reset_parameters
(
self
,
weight_initializer
,
bias_initializer
,
position_embed_initializer
):
self
.
reset_parameters
(
weight_initializer
,
bias_initializer
,
position_embed_initializer
)
fan_in
,
fan_out
=
nn
.
init
.
_calculate_fan_in_and_fan_out
(
self
.
weight
)
weight_initializer
(
self
.
weight
,
fan_in
=
fan_in
,
fan_out
=
fan_out
)
def
reset_parameters
(
self
,
weight_initializer
,
bias_initializer
,
position_embed_initializer
):
bias_initializer
(
self
.
bias
,
fan_in
=
fan_in
)
fan_in
,
fan_out
=
nn
.
init
.
_calculate_fan_in_and_fan_out
(
self
.
weight
)
position_embed_initializer
(
self
.
pos_embed
)
weight_initializer
(
self
.
weight
,
fan_in
=
fan_in
,
fan_out
=
fan_out
)
bias_initializer
(
self
.
bias
,
fan_in
=
fan_in
)
def
forward
(
self
,
input_
:
Tensor
)
->
Tensor
:
position_embed_initializer
(
self
.
pos_embed
)
B
,
C
,
H
,
W
=
input_
.
shape
assert
H
==
self
.
img_size
[
0
]
and
W
==
self
.
img_size
[
1
],
\
def
forward
(
self
,
input_
:
Tensor
)
->
Tensor
:
f
"Input image size (
{
H
}
*
{
W
}
) doesn't match model (
{
self
.
img_size
[
0
]
}
*
{
self
.
img_size
[
1
]
}
)."
B
,
C
,
H
,
W
=
input_
.
shape
output
=
F
.
conv2d
(
input_
,
self
.
weight
,
self
.
bias
,
stride
=
self
.
patch_size
)
assert
H
==
self
.
img_size
[
0
]
and
W
==
self
.
img_size
[
1
],
\
if
self
.
flatten
:
f
"Input image size (
{
H
}
*
{
W
}
) doesn't match model (
{
self
.
img_size
[
0
]
}
*
{
self
.
img_size
[
1
]
}
)."
output
=
output
.
flatten
(
2
).
transpose
(
1
,
2
)
# BCHW -> BNC
output
=
F
.
conv2d
(
input_
,
self
.
weight
,
self
.
bias
,
stride
=
self
.
patch_size
)
if
self
.
flatten
:
cls_token
=
self
.
cls_token
.
expand
(
output
.
shape
[
0
],
-
1
,
-
1
)
output
=
output
.
flatten
(
2
).
transpose
(
1
,
2
)
# BCHW -> BNC
output
=
torch
.
cat
((
cls_token
,
output
),
dim
=
1
)
output
=
output
+
self
.
pos_embed
cls_token
=
self
.
cls_token
.
expand
(
output
.
shape
[
0
],
-
1
,
-
1
)
return
output
output
=
torch
.
cat
((
cls_token
,
output
),
dim
=
1
)
output
=
output
+
self
.
pos_embed
return
output
@
LAYERS
.
register_module
class
VanillaClassifier
(
nn
.
Module
):
r
"""Dense linear classifier.
@
LAYERS
.
register_module
class
VanillaClassifier
(
nn
.
Module
):
Args:
r
"""Dense linear classifier.
in_features (int): size of each input sample.
num_classes (int): number of classes.
Args:
weight (:class:`torch.nn.Parameter`, optional): weight of the classifier, defaults to None.
in_features (int): size of each input sample.
dtype (:class:`torch.dtype`, optional): The dtype of parameters, defaults to None.
num_classes (int): number of classes.
flatten (bool, optional): whether to flatten output tensor, defaults to True.
weight (:class:`torch.nn.Parameter`, optional): weight of the classifier, defaults to None.
weight_initializer (:class:`typing.Callable`, optional):
dtype (:class:`torch.dtype`, optional): The dtype of parameters, defaults to None.
The initializer of weight, defaults to kaiming uniform initializer.
flatten (bool, optional): whether to flatten output tensor, defaults to True.
bias_initializer (:class:`typing.Callable`, optional):
weight_initializer (:class:`typing.Callable`, optional):
The initializer of bias, defaults to xavier uniform initializer.
The initializer of weight, defaults to kaiming uniform initializer.
bias_initializer (:class:`typing.Callable`, optional):
More details about initializer please refer to
The initializer of bias, defaults to xavier uniform initializer.
`init <https://github.com/hpcaitech/ColossalAI/blob/main/colossalai/nn/init.py>`_.
"""
More details about initializer please refer to
`init <https://github.com/hpcaitech/ColossalAI/blob/main/colossalai/nn/init.py>`_.
def
__init__
(
self
,
"""
in_features
:
int
,
num_classes
:
int
,
def
__init__
(
self
,
weight
:
nn
.
Parameter
=
None
,
in_features
:
int
,
bias
:
bool
=
True
,
num_classes
:
int
,
dtype
:
torch
.
dtype
=
None
,
weight
:
nn
.
Parameter
=
None
,
weight_initializer
:
Callable
=
init
.
kaiming_uniform_
(
a
=
math
.
sqrt
(
5
)),
bias
:
bool
=
True
,
bias_initializer
:
Callable
=
init
.
xavier_uniform_
(
a
=
1
,
scale
=
1
)):
dtype
:
torch
.
dtype
=
None
,
super
().
__init__
()
weight_initializer
:
Callable
=
init
.
kaiming_uniform_
(
a
=
math
.
sqrt
(
5
)),
self
.
in_features
=
in_features
bias_initializer
:
Callable
=
init
.
xavier_uniform_
(
a
=
1
,
scale
=
1
)):
self
.
num_classes
=
num_classes
super
().
__init__
()
self
.
in_features
=
in_features
if
weight
is
not
None
:
self
.
num_classes
=
num_classes
self
.
weight
=
weight
self
.
has_weight
=
False
if
weight
is
not
None
:
else
:
self
.
weight
=
weight
self
.
weight
=
nn
.
Parameter
(
self
.
has_weight
=
False
torch
.
empty
(
self
.
num_classes
,
self
.
in_features
,
device
=
get_current_device
(),
dtype
=
dtype
))
else
:
self
.
has_weight
=
True
self
.
weight
=
nn
.
Parameter
(
if
bias
:
torch
.
empty
(
self
.
num_classes
,
self
.
in_features
,
device
=
get_current_device
(),
dtype
=
dtype
))
self
.
bias
=
nn
.
Parameter
(
torch
.
zeros
(
self
.
num_classes
,
device
=
get_current_device
(),
dtype
=
dtype
))
self
.
has_weight
=
True
else
:
if
bias
:
self
.
bias
=
None
self
.
bias
=
nn
.
Parameter
(
torch
.
zeros
(
self
.
num_classes
,
device
=
get_current_device
(),
dtype
=
dtype
))
else
:
self
.
reset_parameters
(
weight_initializer
,
bias_initializer
)
self
.
bias
=
None
def
reset_parameters
(
self
,
weight_initializer
,
bias_initializer
):
self
.
reset_parameters
(
weight_initializer
,
bias_initializer
)
fan_in
,
fan_out
=
self
.
in_features
,
self
.
num_classes
def
reset_parameters
(
self
,
weight_initializer
,
bias_initializer
):
if
self
.
has_weight
:
fan_in
,
fan_out
=
self
.
in_features
,
self
.
num_classes
weight_initializer
(
self
.
weight
,
fan_in
=
fan_in
,
fan_out
=
fan_out
)
if
self
.
has_weight
:
if
self
.
bias
is
not
None
:
weight_initializer
(
self
.
weight
,
fan_in
=
fan_in
,
fan_out
=
fan_out
)
bias_initializer
(
self
.
bias
,
fan_in
=
fan_in
)
if
self
.
bias
is
not
None
:
def
forward
(
self
,
input_
:
Tensor
)
->
Tensor
:
bias_initializer
(
self
.
bias
,
fan_in
=
fan_in
)
return
F
.
linear
(
input_
,
self
.
weight
,
self
.
bias
)
def
forward
(
self
,
input_
:
Tensor
)
->
Tensor
:
return
F
.
linear
(
input_
,
self
.
weight
,
self
.
bias
)
@
LAYERS
.
register_module
class
VanillaLayerNorm
(
nn
.
Module
):
r
"""
@
LAYERS
.
register_module
Layer Normalization for colossalai
class
VanillaLayerNorm
(
nn
.
Module
):
r
"""
Args:
Layer Normalization for colossalai
normalized_shape (int): input shape from an expected input of size.
:math:`[* \times \text{normalized_shape}[0] \times \text{normalized_shape}[1]
Args:
\times \ldots \times \text{normalized_shape}[-1]]`
normalized_shape (int): input shape from an expected input of size.
If a single integer is used, it is treated as a singleton list, and this module will
:math:`[* \times \text{normalized_shape}[0] \times \text{normalized_shape}[1]
normalize over the last dimension which is expected to be of that specific size.
\times \ldots \times \text{normalized_shape}[-1]]`
eps (float): a value added to the denominator for numerical stability, defaults to 1e-05.
If a single integer is used, it is treated as a singleton list, and this module will
bias (bool, optional): Whether to add a bias, defaults to ``True``.
normalize over the last dimension which is expected to be of that specific size.
dtype (:class:`torch.dtype`, optional): The dtype of parameters, defaults to None.
eps (float): a value added to the denominator for numerical stability, defaults to 1e-05.
"""
bias (bool, optional): Whether to add a bias, defaults to ``True``.
dtype (:class:`torch.dtype`, optional): The dtype of parameters, defaults to None.
def
__init__
(
self
,
normalized_shape
:
int
,
eps
=
1e-05
,
bias
=
True
,
dtype
=
None
):
"""
super
().
__init__
()
def
__init__
(
self
,
normalized_shape
:
int
,
eps
=
1e-05
,
bias
=
True
,
dtype
=
None
):
self
.
normalized_shape
=
(
normalized_shape
,)
super
().
__init__
()
self
.
variance_epsilon
=
eps
self
.
normalized_shape
=
(
normalized_shape
,)
factory_kwargs
=
{
'device'
:
get_current_device
(),
'dtype'
:
dtype
}
self
.
variance_epsilon
=
eps
self
.
weight
=
nn
.
Parameter
(
torch
.
ones
(
normalized_shape
,
**
factory_kwargs
))
factory_kwargs
=
{
'device'
:
get_current_device
(),
'dtype'
:
dtype
}
if
bias
:
self
.
bias
=
nn
.
Parameter
(
torch
.
zeros
(
normalized_shape
,
**
factory_kwargs
))
self
.
weight
=
nn
.
Parameter
(
torch
.
ones
(
normalized_shape
,
**
factory_kwargs
))
else
:
if
bias
:
self
.
bias
=
None
self
.
bias
=
nn
.
Parameter
(
torch
.
zeros
(
normalized_shape
,
**
factory_kwargs
))
else
:
def
forward
(
self
,
x
:
Tensor
)
->
Tensor
:
self
.
bias
=
None
return
F
.
layer_norm
(
x
,
self
.
normalized_shape
,
self
.
weight
,
self
.
bias
,
self
.
variance_epsilon
)
def
forward
(
self
,
x
:
Tensor
)
->
Tensor
:
return
F
.
layer_norm
(
x
,
self
.
normalized_shape
,
self
.
weight
,
self
.
bias
,
self
.
variance_epsilon
)
@
LAYERS
.
register_module
class
VanillaLinear
(
nn
.
Module
):
"""Linear layer.
Args:
in_features (int): size of each input sample.
out_features (int): size of each output sample.
bias (bool, optional): If set to ``False``, the layer will not learn an additive bias, defaults to ``True``.
dtype (:class:`torch.dtype`, optional): The dtype of parameters, defaults to None.
skip_bias_add: bool (optional, default to be false).
weight_initializer (:class:`typing.Callable`, optional):
The initializer of weight, defaults to kaiming uniform initializer.
bias_initializer (:class:`typing.Callable`, optional):
The initializer of bias, defaults to xavier uniform initializer.
More details about ``initializer`` please refer to
`init <https://github.com/hpcaitech/ColossalAI/blob/main/colossalai/nn/init.py>`_.
"""
def
__init__
(
self
,
in_features
:
int
,
out_features
:
int
,
bias
:
bool
=
True
,
dtype
:
torch
.
dtype
=
None
,
skip_bias_add
:
bool
=
False
,
weight_initializer
:
Callable
=
init
.
kaiming_uniform_
(
a
=
math
.
sqrt
(
5
)),
bias_initializer
:
Callable
=
init
.
xavier_uniform_
(
a
=
1
,
scale
=
1
),
**
kwargs
)
->
None
:
super
().
__init__
()
self
.
in_features
=
in_features
self
.
out_features
=
out_features
self
.
skip_bias_add
=
skip_bias_add
factory_kwargs
=
{
'device'
:
get_current_device
(),
'dtype'
:
dtype
}
self
.
weight
=
Parameter
(
torch
.
empty
(
self
.
out_features
,
self
.
in_features
,
**
factory_kwargs
))
if
bias
:
self
.
bias
=
Parameter
(
torch
.
empty
(
self
.
out_features
,
**
factory_kwargs
))
else
:
self
.
bias
=
None
weight_initializer
(
self
.
weight
,
fan_in
=
in_features
,
fan_out
=
out_features
)
if
self
.
bias
is
not
None
:
bias_initializer
(
self
.
bias
,
fan_in
=
in_features
)
def
forward
(
self
,
input
:
Tensor
)
->
Tensor
:
if
not
self
.
skip_bias_add
:
return
F
.
linear
(
input
,
self
.
weight
,
self
.
bias
)
else
:
return
F
.
linear
(
input
,
self
.
weight
),
self
.
bias
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment