Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
ColossalAI
Commits
763dc325
Unverified
Commit
763dc325
authored
Mar 30, 2022
by
Ziyue Jiang
Committed by
GitHub
Mar 30, 2022
Browse files
[TP] Add gather_out arg to Linear (#541)
parent
8c90d4df
Changes
1
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
14 additions
and
9 deletions
+14
-9
colossalai/nn/layer/colossalai_layer/linear.py
colossalai/nn/layer/colossalai_layer/linear.py
+14
-9
No files found.
colossalai/nn/layer/colossalai_layer/linear.py
View file @
763dc325
import
math
import
inspect
from
typing
import
Callable
from
colossalai.utils
import
get_current_device
...
...
@@ -78,7 +79,11 @@ class Linear(nn.Module):
if
self
.
layer
.
bias
is
not
None
:
bias_initializer
(
self
.
layer
.
bias
,
fan_in
=
in_features
)
else
:
self
.
layer
=
_parallel_linear
[
tensor_parallel
](
linear_cls
=
_parallel_linear
[
tensor_parallel
]
gather_output
=
kwargs
.
pop
(
'gather_output'
,
None
)
if
'gather_output'
in
inspect
.
signature
(
linear_cls
.
__init__
).
parameters
.
keys
():
# gather_out arg is available
kwargs
[
'gather_output'
]
=
gather_output
self
.
layer
=
linear_cls
(
in_features
,
out_features
,
bias
=
bias
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment