Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
ColossalAI
Commits
763dc325
"...git@developer.sourcefind.cn:OpenDAS/colossalai.git" did not exist on "3dbbf83f1c46ae2a3b2947e1a5925c2b8af9f7b1"
Unverified
Commit
763dc325
authored
Mar 30, 2022
by
Ziyue Jiang
Committed by
GitHub
Mar 30, 2022
Browse files
[TP] Add gather_out arg to Linear (#541)
parent
8c90d4df
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
14 additions
and
9 deletions
+14
-9
colossalai/nn/layer/colossalai_layer/linear.py
colossalai/nn/layer/colossalai_layer/linear.py
+14
-9
No files found.
colossalai/nn/layer/colossalai_layer/linear.py
View file @
763dc325
import
math
import
inspect
from
typing
import
Callable
from
colossalai.utils
import
get_current_device
...
...
@@ -78,15 +79,19 @@ class Linear(nn.Module):
if
self
.
layer
.
bias
is
not
None
:
bias_initializer
(
self
.
layer
.
bias
,
fan_in
=
in_features
)
else
:
self
.
layer
=
_parallel_linear
[
tensor_parallel
](
in_features
,
out_features
,
bias
=
bias
,
dtype
=
dtype
,
weight_initializer
=
weight_initializer
,
bias_initializer
=
bias_initializer
,
**
kwargs
,
)
linear_cls
=
_parallel_linear
[
tensor_parallel
]
gather_output
=
kwargs
.
pop
(
'gather_output'
,
None
)
if
'gather_output'
in
inspect
.
signature
(
linear_cls
.
__init__
).
parameters
.
keys
():
# gather_out arg is available
kwargs
[
'gather_output'
]
=
gather_output
self
.
layer
=
linear_cls
(
in_features
,
out_features
,
bias
=
bias
,
dtype
=
dtype
,
weight_initializer
=
weight_initializer
,
bias_initializer
=
bias_initializer
,
**
kwargs
,
)
@
property
def
weight
(
self
):
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment