Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
ColossalAI
Commits
116e3d0b
Commit
116e3d0b
authored
Jan 04, 2023
by
ver217
Committed by
Frank Lee
Jan 04, 2023
Browse files
[NFC] polish communication/p2p_v2.py code style (#2303)
parent
b965585d
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
10 additions
and
10 deletions
+10
-10
colossalai/communication/p2p_v2.py
colossalai/communication/p2p_v2.py
+10
-10
No files found.
colossalai/communication/p2p_v2.py
View file @
116e3d0b
#!/usr/bin/env python
#!/usr/bin/env python
# -*- encoding: utf-8 -*-
# -*- encoding: utf-8 -*-
from
typing
import
List
,
Tuple
,
Union
,
Any
import
pickle
import
io
import
io
import
pickle
from
typing
import
Any
,
List
,
Tuple
,
Union
import
torch
import
torch
import
torch.distributed
as
dist
import
torch.distributed
as
dist
from
torch.distributed
import
distributed_c10d
as
c10d
from
torch.distributed
import
ProcessGroupNCCL
from
torch.distributed
import
ProcessGroupNCCL
from
torch.distributed
import
distributed_c10d
as
c10d
from
colossalai.context.parallel_mode
import
ParallelMode
from
colossalai.context.parallel_mode
import
ParallelMode
from
colossalai.core
import
global_context
as
gpc
from
colossalai.core
import
global_context
as
gpc
...
@@ -23,7 +23,7 @@ def init_process_group():
...
@@ -23,7 +23,7 @@ def init_process_group():
Args:
Args:
None
None
Returns:
Returns:
None
None
"""
"""
...
@@ -40,7 +40,7 @@ def _acquire_pair_group_handle(first_rank: int, second_rank: int) -> ProcessGrou
...
@@ -40,7 +40,7 @@ def _acquire_pair_group_handle(first_rank: int, second_rank: int) -> ProcessGrou
second_rank (int): second rank in the pair
second_rank (int): second rank in the pair
Returns:
Returns:
:class:`ProcessGroupNCCL`: the handle of the group consisting of the given two ranks
:class:`ProcessGroupNCCL`: the handle of the group consisting of the given two ranks
"""
"""
if
len
(
_pg_manager
)
==
0
:
if
len
(
_pg_manager
)
==
0
:
init_process_group
()
init_process_group
()
...
@@ -51,8 +51,8 @@ def _acquire_pair_group_handle(first_rank: int, second_rank: int) -> ProcessGrou
...
@@ -51,8 +51,8 @@ def _acquire_pair_group_handle(first_rank: int, second_rank: int) -> ProcessGrou
def
_cuda_safe_tensor_to_object
(
tensor
:
torch
.
Tensor
,
tensor_size
:
torch
.
Size
)
->
object
:
def
_cuda_safe_tensor_to_object
(
tensor
:
torch
.
Tensor
,
tensor_size
:
torch
.
Size
)
->
object
:
"""transform tensor to object with unpickle.
"""transform tensor to object with unpickle.
Info of the device in bytes stream will be modified into current device before unpickling
Info of the device in bytes stream will be modified into current device before unpickling
Args:
Args:
tensor (:class:`torch.tensor`): tensor to be unpickled
tensor (:class:`torch.tensor`): tensor to be unpickled
...
@@ -78,9 +78,9 @@ def _cuda_safe_tensor_to_object(tensor: torch.Tensor, tensor_size: torch.Size) -
...
@@ -78,9 +78,9 @@ def _cuda_safe_tensor_to_object(tensor: torch.Tensor, tensor_size: torch.Size) -
def
_broadcast_object_list
(
object_list
:
List
[
Any
],
src
:
int
,
dst
:
int
,
device
=
None
):
def
_broadcast_object_list
(
object_list
:
List
[
Any
],
src
:
int
,
dst
:
int
,
device
=
None
):
"""This is a modified version of the broadcast_object_list in torch.distribution
"""This is a modified version of the broadcast_object_list in torch.distribution
The only difference is that object will be move to correct device after unpickled.
The only difference is that object will be move to correct device after unpickled.
If local_rank = src, then object list will be sent to rank src. Otherwise, object list will
If local_rank = src, then object list will be sent to rank src. Otherwise, object list will
be updated with data sent from rank src.
be updated with data sent from rank src.
Args:
Args:
object_list (List[Any]): list of object to broadcast
object_list (List[Any]): list of object to broadcast
src (int): source rank to broadcast
src (int): source rank to broadcast
...
@@ -182,7 +182,7 @@ def _recv_object(src: int) -> Any:
...
@@ -182,7 +182,7 @@ def _recv_object(src: int) -> Any:
Args:
Args:
src (int): source rank of data. local rank will receive data from src rank.
src (int): source rank of data. local rank will receive data from src rank.
Returns:
Returns:
Any: Object received from src.
Any: Object received from src.
"""
"""
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment