OpenDAS / ColossalAI / Commits / 653b0a62

Commit 653b0a62
Authored Nov 09, 2022 by zbian; committed by アマデウス on Nov 09, 2022

    added skip_bias_add for non-tp linear

Parent: e5b1a0c9
Changes: 3

Showing 3 changed files with 493 additions and 440 deletions (+493 -440)
colossalai/nn/layer/colossalai_layer/linear.py    +141 -147
colossalai/nn/layer/vanilla/__init__.py           +11  -3
colossalai/nn/layer/vanilla/layers.py             +341 -290
colossalai/nn/layer/colossalai_layer/linear.py

```diff
-import math
 import inspect
+import math
 from typing import Callable
 
-from colossalai.utils import get_current_device
 from torch import dtype, nn
 
+from colossalai.utils import get_current_device
+
 from ... import init as init
 from ..parallel_1d import *
 from ..parallel_2d import *
 ...
@@ -14,7 +15,7 @@ from ..utils import get_tensor_parallel_mode
 from ..vanilla import *
 from ._utils import ColossalaiModule
 
-_parallel_linear = {'1d': Linear1D, '2d': Linear2D, '2.5d': Linear2p5D, '3d': Linear3D}
+_parallel_linear = {None: VanillaLinear, '1d': Linear1D, '2d': Linear2D, '2.5d': Linear2p5D, '3d': Linear3D}
 
 _parallel_classifier = {
     None: VanillaClassifier,
 ...
@@ -73,16 +74,9 @@ class Linear(ColossalaiModule):
                  bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1),
                  **kwargs) -> None:
         tensor_parallel = get_tensor_parallel_mode()
-        if tensor_parallel is None:
-            layer = nn.Linear(in_features, out_features, bias=bias).to(dtype).to(get_current_device())
-            weight_initializer(layer.weight, fan_in=in_features, fan_out=out_features)
-            if layer.bias is not None:
-                bias_initializer(layer.bias, fan_in=in_features)
-        else:
-            linear_cls = _parallel_linear[tensor_parallel]
-            gather_output = kwargs.pop('gather_output', None)
-            if 'gather_output' in inspect.signature(linear_cls.__init__).parameters.keys():    # gather_out arg is available
-                kwargs['gather_output'] = gather_output
-            layer = linear_cls(in_features,
+        linear_cls = _parallel_linear[tensor_parallel]
+        gather_output = kwargs.pop('gather_output', None)
+        if 'gather_output' in inspect.signature(linear_cls.__init__).parameters.keys():    # gather_out arg is available
+            kwargs['gather_output'] = gather_output
+        layer = linear_cls(in_features,
 ...
```
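With `None: VanillaLinear` now in the `_parallel_linear` table, `Linear.__init__` no longer needs a special `nn.Linear` branch for the non-tensor-parallel case: every mode, including no parallelism at all, is resolved through the same dictionary lookup, so keyword arguments such as `skip_bias_add` reach `VanillaLinear` the same way they reach the 1D/2D/2.5D/3D layers. Below is a minimal, self-contained sketch of that dispatch pattern; `build_linear` and the stub `VanillaLinear` are illustrative stand-ins, not ColossalAI's actual classes.

```python
import inspect

from torch import nn


class VanillaLinear(nn.Linear):
    """Stand-in for ColossalAI's VanillaLinear, used only to illustrate dispatch."""
    pass


# Hypothetical dispatch table; the real one also maps '1d', '2d', '2.5d', '3d'.
_parallel_linear = {None: VanillaLinear}


def build_linear(in_features: int, out_features: int, tensor_parallel=None, **kwargs):
    linear_cls = _parallel_linear[tensor_parallel]
    # Forward gather_output only when the selected class actually accepts it.
    gather_output = kwargs.pop('gather_output', None)
    if 'gather_output' in inspect.signature(linear_cls.__init__).parameters:
        kwargs['gather_output'] = gather_output
    return linear_cls(in_features, out_features, **kwargs)


layer = build_linear(16, 32)    # a non-tp request resolves to VanillaLinear
print(type(layer).__name__)     # VanillaLinear
```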
colossalai/nn/layer/vanilla/__init__.py

```diff
-from .layers import (DropPath, VanillaClassifier, VanillaLayerNorm, VanillaPatchEmbedding, WrappedDropout,
-                     WrappedDropPath)
+from .layers import (
+    DropPath,
+    VanillaClassifier,
+    VanillaLayerNorm,
+    VanillaLinear,
+    VanillaPatchEmbedding,
+    WrappedDropout,
+    WrappedDropPath,
+)
 
-__all__ = ["VanillaLayerNorm", "VanillaPatchEmbedding", "VanillaClassifier", "DropPath", "WrappedDropout", "WrappedDropPath"]
+__all__ = [
+    "VanillaLayerNorm", "VanillaPatchEmbedding", "VanillaClassifier", "DropPath", "WrappedDropout", "WrappedDropPath",
+    "VanillaLinear"
+]
```
colossalai/nn/layer/vanilla/layers.py

```diff
@@ -3,12 +3,14 @@ from typing import Callable
 
 import torch
 import torch.nn.functional as F
+from torch import Tensor
+from torch import nn as nn
+from torch.nn.parameter import Parameter
+
 from colossalai.context import seed
 from colossalai.nn import init as init
 from colossalai.registry import LAYERS
 from colossalai.utils.cuda import get_current_device
-from torch import Tensor
-from torch import nn as nn
 
 from ..utils import to_2tuple
 ...
@@ -288,3 +290,52 @@ class VanillaLayerNorm(nn.Module):
 
     def forward(self, x: Tensor) -> Tensor:
         return F.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.variance_epsilon)
+
+
+@LAYERS.register_module
+class VanillaLinear(nn.Module):
+    """Linear layer.
+
+    Args:
+        in_features (int): size of each input sample.
+        out_features (int): size of each output sample.
+        bias (bool, optional): If set to ``False``, the layer will not learn an additive bias, defaults to ``True``.
+        dtype (:class:`torch.dtype`, optional): The dtype of parameters, defaults to None.
+        skip_bias_add: bool (optional, default to be false).
+        weight_initializer (:class:`typing.Callable`, optional):
+            The initializer of weight, defaults to kaiming uniform initializer.
+        bias_initializer (:class:`typing.Callable`, optional):
+            The initializer of bias, defaults to xavier uniform initializer.
+
+    More details about ``initializer`` please refer to
+    `init <https://github.com/hpcaitech/ColossalAI/blob/main/colossalai/nn/init.py>`_.
+    """
+
+    def __init__(self,
+                 in_features: int,
+                 out_features: int,
+                 bias: bool = True,
+                 dtype: torch.dtype = None,
+                 skip_bias_add: bool = False,
+                 weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)),
+                 bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1),
+                 **kwargs) -> None:
+        super().__init__()
+        self.in_features = in_features
+        self.out_features = out_features
+        self.skip_bias_add = skip_bias_add
+        factory_kwargs = {'device': get_current_device(), 'dtype': dtype}
+        self.weight = Parameter(torch.empty(self.out_features, self.in_features, **factory_kwargs))
+        if bias:
+            self.bias = Parameter(torch.empty(self.out_features, **factory_kwargs))
+        else:
+            self.bias = None
+        weight_initializer(self.weight, fan_in=in_features, fan_out=out_features)
+        if self.bias is not None:
+            bias_initializer(self.bias, fan_in=in_features)
+
+    def forward(self, input: Tensor) -> Tensor:
+        if not self.skip_bias_add:
+            return F.linear(input, self.weight, self.bias)
+        else:
+            return F.linear(input, self.weight), self.bias
```
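The point of `skip_bias_add` is that the layer returns the un-biased matmul output together with the bias tensor, so the caller can apply the bias later, typically fused with whatever op follows, instead of paying for a separate add inside the linear layer. A minimal plain-PyTorch sketch of the pattern follows; it uses `torch.nn.Linear` rather than `VanillaLinear` so it runs without a ColossalAI setup, and the `gelu` call merely stands in for a fused bias+activation kernel.

```python
import torch
import torch.nn.functional as F

linear = torch.nn.Linear(16, 32)
x = torch.randn(4, 16)

# skip_bias_add=False behaviour: the bias is added inside the layer.
y = F.linear(x, linear.weight, linear.bias)

# skip_bias_add=True behaviour: get (output, bias) and defer the bias add,
# e.g. so it can be fused with the activation that follows.
out, bias = F.linear(x, linear.weight), linear.bias
y_deferred = F.gelu(out + bias)    # placeholder for a fused bias+activation op

torch.testing.assert_close(out + bias, y)   # the deferred add matches the eager path
```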