OpenDAS / ColossalAI · Commit 3a51d909

fix format (#332)

Authored Mar 09, 2022 by Cautiousss; committed by Frank Lee on Mar 11, 2022.
Co-authored-by: 何晓昕 <cautious@r-205-106-25-172.comp.nus.edu.sg>
Parent: cbb6436f
Showing 1 changed file with 3 additions and 3 deletions.

colossalai/communication/utils.py (+3 −3)

In each hunk below, the removed and added lines read identically; given the commit message, the changes appear to be whitespace-only formatting fixes.
@@ -7,7 +7,7 @@ from colossalai.utils import get_current_device
 def send_tensor_meta(tensor, need_meta=True, next_rank=None):
-    """Sends tensor meta information before sending a specific tensor.
+    """Sends tensor meta information before sending a specific tensor.
     Since the recipient must know the shape of the tensor in p2p communications,
     meta information of the tensor should be sent before communications. This function
     synchronizes with :func:`recv_tensor_meta`.
@@ -36,7 +36,7 @@ def send_tensor_meta(tensor, need_meta=True, next_rank=None):
 def recv_tensor_meta(tensor_shape, prev_rank=None):
-    """Recieves tensor meta information before recieving a specific tensor.
+    """Recieves tensor meta information before recieving a specific tensor.
     Since the recipient must know the shape of the tensor in p2p communications,
     meta information of the tensor should be recieved before communications. This function
     synchronizes with :func:`send_tensor_meta`.
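The two docstrings above describe one handshake: before a point-to-point transfer, the sender ships the tensor's shape so the receiver can pre-allocate a matching buffer. A minimal sketch of that protocol over plain torch.distributed follows; the function names and the explicit dtype parameter are illustrative assumptions, not ColossalAI's implementation, which resolves peer ranks through gpc.

import torch
import torch.distributed as dist


def send_meta_then_tensor(tensor, next_rank):
    """Send ndim, then the shape, then the payload itself."""
    ndim = torch.tensor([tensor.dim()], dtype=torch.long)
    shape = torch.tensor(tensor.size(), dtype=torch.long)
    dist.send(ndim, dst=next_rank)
    dist.send(shape, dst=next_rank)
    dist.send(tensor, dst=next_rank)


def recv_meta_then_tensor(prev_rank, dtype=torch.float32):
    """Receive ndim and the shape first, pre-allocate, then receive the tensor."""
    ndim = torch.empty(1, dtype=torch.long)
    dist.recv(ndim, src=prev_rank)
    shape = torch.empty(int(ndim.item()), dtype=torch.long)
    dist.recv(shape, src=prev_rank)
    # dtype is assumed to be agreed on by both ranks; only the shape is exchanged.
    buf = torch.empty(torch.Size(shape.tolist()), dtype=dtype)
    dist.recv(buf, src=prev_rank)
    return buf

With a gloo backend these CPU tensors work as-is; under NCCL every buffer must live on a GPU, which is presumably why this file imports get_current_device.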
@@ -104,6 +104,6 @@ def gather_split_1d_tensor(tensor):
     gathered = torch.empty(numel_gathered, dtype=tensor.dtype,
                            device=torch.cuda.current_device(), requires_grad=False)
-    chunks = [gathered[i * numel:(i + 1) * numel] for i in range(world_size)]
+    chunks = [gathered[i * numel:(i + 1) * numel] for i in range(world_size)]
     dist.all_gather(chunks, tensor, group=gpc.get_group(ParallelMode.PARALLEL_1D))
     return gathered
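The pattern in this last hunk avoids a post-gather concatenation: each chunk is a view aliasing one preallocated flat buffer, so dist.all_gather fills gathered in place. Below is a self-contained sketch of the same idea, assuming an initialized default process group in place of ColossalAI's gpc.get_group(ParallelMode.PARALLEL_1D).

import torch
import torch.distributed as dist


def gather_split_1d_tensor(shard):
    """Reassemble a tensor that was split evenly, as flat 1D slices, across the group.

    `shard` is this rank's flat slice; the return value is the full flat buffer.
    """
    world_size = dist.get_world_size()
    numel = shard.numel()
    gathered = torch.empty(world_size * numel, dtype=shard.dtype,
                           device=shard.device, requires_grad=False)
    # Each chunk is a view into `gathered`, so all_gather writes every rank's
    # slice directly into the output buffer with no extra copy.
    chunks = [gathered[i * numel:(i + 1) * numel] for i in range(world_size)]
    dist.all_gather(chunks, shard.contiguous(), group=dist.group.WORLD)
    return gathered

Slicing a contiguous 1D tensor yields contiguous views, so each entry of chunks is a valid all_gather output buffer; the caller can then view(...) the result back to the original shape.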