OpenDAS / ColossalAI · Commits

Commit 2c42b230 (unverified), authored Jun 02, 2022 by アマデウス, committed by GitHub on Jun 02, 2022

updated collective ops api (#1054)

Parent: 51b9a496
Showing 1 changed file with 6 additions and 15 deletions:

    colossalai/communication/collective.py (+6, -15)
colossalai/communication/collective.py @ 2c42b230
@@ -13,7 +13,6 @@ from colossalai.core import global_context as gpc
 def all_gather(tensor: Tensor,
                dim: int,
                parallel_mode: ParallelMode,
-               on_cpu: bool = False,
                async_op: bool = False) -> Tensor:
     r"""Gathers all tensors from the parallel group and concatenates them in a
     specific dimension.
@@ -26,7 +25,6 @@ def all_gather(tensor: Tensor,
         tensor (:class:`torch.Tensor`): Tensor to be gathered.
         dim (int): The dimension concatenating in.
         parallel_mode (:class:`colossalai.context.ParallelMode`): Parallel group mode used in this communication.
-        on_cpu (bool, optional): Whether to communicate with Gloo backend.
         async_op (bool, optional): Whether operations are asynchronous.

     Returns:
@@ -43,7 +41,7 @@ def all_gather(tensor: Tensor,
         shape[0] *= depth
         out = torch.empty(shape, dtype=tensor.dtype, device=tensor.device)
         temp = list(torch.chunk(out, depth, dim=0))
-        group = gpc.get_cpu_group(parallel_mode) if on_cpu else gpc.get_group(parallel_mode)
+        group = gpc.get_cpu_group(parallel_mode) if tensor.device.type == "cpu" else gpc.get_group(parallel_mode)
         work = dist.all_gather(tensor_list=temp,
                                tensor=tensor.transpose(0, dim).contiguous(),
                                group=group,
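With this change, all_gather no longer takes an on_cpu flag; the Gloo (CPU) group is selected whenever the input tensor lives on the CPU. The following is a minimal usage sketch, not part of the commit, and it assumes the distributed context has already been initialized (e.g. via colossalai.launch) with a working GLOBAL parallel group:

    import torch
    from colossalai.communication.collective import all_gather
    from colossalai.context import ParallelMode

    # Before this commit, a CPU/Gloo collective required all_gather(..., on_cpu=True).
    # After it, the group is inferred from the tensor's device.
    gpu_tensor = torch.randn(4, 8, device="cuda")   # routed through the default device group
    cpu_tensor = torch.randn(4, 8)                  # device.type == "cpu" -> Gloo (CPU) group

    gathered_gpu = all_gather(gpu_tensor, dim=0, parallel_mode=ParallelMode.GLOBAL)
    gathered_cpu = all_gather(cpu_tensor, dim=0, parallel_mode=ParallelMode.GLOBAL)

    # async_op=True still returns (tensor, work); wait on the handle before reading the result.
    out, work = all_gather(gpu_tensor, dim=0, parallel_mode=ParallelMode.GLOBAL, async_op=True)
    work.wait()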
@@ -59,7 +57,6 @@ def reduce_scatter(tensor: Tensor,
                    dim: int,
                    parallel_mode: ParallelMode,
                    op: ReduceOp = ReduceOp.SUM,
-                   on_cpu: bool = False,
                    async_op: bool = False) -> Tensor:
     r"""Reduces all tensors then scatters it in a specific dimension to all
     members in the parallel group.
@@ -76,7 +73,6 @@ def reduce_scatter(tensor: Tensor,
             should be included in [SUM, AVG, PRODUCT, MIN, MAX, BAND, BOR, BXOR].
             More details about ReduceOp please refer to
             `ReduceOp <https://pytorch.org/docs/stable/distributed.html#torch.distributed.ReduceOp>`_.
-        on_cpu (bool, optional): Whether to communicate with Gloo backend.
         async_op (bool, optional): Whether operations are asynchronous.

     Returns:
@@ -90,7 +86,7 @@ def reduce_scatter(tensor: Tensor,
     else:
         temp = list(map(lambda x: x.contiguous(), torch.chunk(tensor, depth, dim=dim)))
         out = torch.empty(temp[0].shape, dtype=tensor.dtype, device=tensor.device)
-        group = gpc.get_cpu_group(parallel_mode) if on_cpu else gpc.get_group(parallel_mode)
+        group = gpc.get_cpu_group(parallel_mode) if tensor.device.type == "cpu" else gpc.get_group(parallel_mode)
         work = dist.reduce_scatter(output=out, input_list=temp, op=op, group=group, async_op=async_op)
         if async_op:
             return out, work
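As with all_gather, reduce_scatter now picks the Gloo group when the input tensor is on the CPU instead of consulting an on_cpu flag. A hedged sketch of the updated call, again assuming an already-launched ColossalAI context (shapes are illustrative only):

    import torch
    from torch.distributed import ReduceOp
    from colossalai.communication.collective import reduce_scatter
    from colossalai.context import ParallelMode
    from colossalai.core import global_context as gpc

    world_size = gpc.get_world_size(ParallelMode.GLOBAL)

    # The size along the scatter dimension must be divisible by the group size.
    x = torch.ones(world_size * 2, 8, device="cuda")

    # Each rank receives a (2, 8) slice of the element-wise sum across the group.
    # Placing x on the CPU instead would route the call through the Gloo group.
    shard = reduce_scatter(x, dim=0, parallel_mode=ParallelMode.GLOBAL, op=ReduceOp.SUM)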
@@ -101,7 +97,6 @@ def reduce_scatter(tensor: Tensor,

 def all_reduce(tensor: Tensor,
                parallel_mode: ParallelMode,
                op: ReduceOp = ReduceOp.SUM,
-               on_cpu: bool = False,
                async_op: bool = False) -> Tensor:
     r"""Reduces the tensor data across whole parallel group in such a way that all get the final result.
@@ -116,7 +111,6 @@ def all_reduce(tensor: Tensor,
             should be included in [SUM, AVG, PRODUCT, MIN, MAX, BAND, BOR, BXOR].
             More details about ReduceOp please refer to
             `ReduceOp <https://pytorch.org/docs/stable/distributed.html#torch.distributed.ReduceOp>`_.
-        on_cpu (bool, optional): Whether to communicate with Gloo backend.
         async_op (bool, optional): Whether operations are asynchronous.

     Returns:
@@ -129,7 +123,7 @@ def all_reduce(tensor: Tensor,
         work = None
     else:
         out = tensor.contiguous()
-        group = gpc.get_cpu_group(parallel_mode) if on_cpu else gpc.get_group(parallel_mode)
+        group = gpc.get_cpu_group(parallel_mode) if tensor.device.type == "cpu" else gpc.get_group(parallel_mode)
         work = dist.all_reduce(out, op=op, group=group, async_op=async_op)
         if async_op:
             return out, work
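all_reduce follows the same pattern: the CPU/Gloo path is now requested by passing a CPU tensor rather than on_cpu=True. A small illustrative sketch under the same assumptions as above:

    import torch
    from torch.distributed import ReduceOp
    from colossalai.communication.collective import all_reduce
    from colossalai.context import ParallelMode

    grad = torch.randn(1024, device="cuda")

    # Synchronous sum across the group; every rank ends up with the same result.
    summed = all_reduce(grad, parallel_mode=ParallelMode.GLOBAL, op=ReduceOp.SUM)

    # Asynchronous variant: overlap the reduction with other work, then wait on the handle.
    out, handle = all_reduce(grad, parallel_mode=ParallelMode.GLOBAL, async_op=True)
    # ... other computation ...
    handle.wait()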
@@ -137,7 +131,7 @@ def all_reduce(tensor: Tensor,
     return out


-def broadcast(tensor: Tensor, src: int, parallel_mode: ParallelMode, on_cpu: bool = False, async_op: bool = False):
+def broadcast(tensor: Tensor, src: int, parallel_mode: ParallelMode, async_op: bool = False):
     r"""Broadcast tensors to whole parallel group. Tensor must have the same
     number of elements in all processes participating in the collective.
@@ -149,7 +143,6 @@ def broadcast(tensor: Tensor, src: int, parallel_mode: ParallelMode, on_cpu: boo
         tensor (:class:`torch.Tensor`): Tensor to be broadcast.
         src (int): Source rank.
         parallel_mode (:class:`colossalai.context.ParallelMode`): Parallel group mode used in this communication.
-        on_cpu (bool, optional): Whether to communicate with Gloo backend.
         async_op (bool, optional): Whether operations are asynchronous.

     Returns:
@@ -162,7 +155,7 @@ def broadcast(tensor: Tensor, src: int, parallel_mode: ParallelMode, on_cpu: boo
         work = None
     else:
         out = tensor.contiguous()
-        group = gpc.get_cpu_group(parallel_mode) if on_cpu else gpc.get_group(parallel_mode)
+        group = gpc.get_cpu_group(parallel_mode) if tensor.device.type == "cpu" else gpc.get_group(parallel_mode)
         work = dist.broadcast(out, src=src, group=group, async_op=async_op)
         if async_op:
             return out, work
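broadcast is the one function whose signature is rewritten on a single line above: on_cpu disappears entirely, and callers that previously passed on_cpu=True should instead move the tensor to the CPU before the call. A sketch of the updated usage, under the same launched-context assumption:

    import torch
    from colossalai.communication.collective import broadcast
    from colossalai.context import ParallelMode
    from colossalai.core import global_context as gpc

    rank = gpc.get_global_rank()

    # Rank 0 holds the data; every other rank provides a buffer of the same shape.
    payload = torch.arange(8.0, device="cuda") if rank == 0 else torch.empty(8, device="cuda")

    # After the call, every rank in the group holds rank 0's tensor.
    payload = broadcast(payload, src=0, parallel_mode=ParallelMode.GLOBAL)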
@@ -174,7 +167,6 @@ def reduce(tensor: Tensor,
            dst: int,
            parallel_mode: ParallelMode,
            op: ReduceOp = ReduceOp.SUM,
-           on_cpu: bool = False,
            async_op: bool = False):
     r"""Reduce tensors across whole parallel group. Only the process with
     rank ``dst`` is going to receive the final result.
@@ -187,7 +179,6 @@ def reduce(tensor: Tensor,
         tensor (:class:`torch.Tensor`): Tensor to be reduced.
         dst (int): Destination rank.
         parallel_mode (:class:`colossalai.context.ParallelMode`): Parallel group mode used in this communication.
-        on_cpu (bool, optional): Whether to communicate with Gloo backend.
         async_op (bool, optional): Whether operations are asynchronous.

     Returns:
@@ -200,7 +191,7 @@ def reduce(tensor: Tensor,
         work = None
     else:
         out = tensor.contiguous()
-        group = gpc.get_cpu_group(parallel_mode) if on_cpu else gpc.get_group(parallel_mode)
+        group = gpc.get_cpu_group(parallel_mode) if tensor.device.type == "cpu" else gpc.get_group(parallel_mode)
         work = dist.reduce(out, dst=dst, op=op, group=group, async_op=async_op)
         if async_op:
             return out, work
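reduce mirrors the other collectives: the Gloo group is chosen when the tensor is on the CPU, and per the docstring only the destination rank is guaranteed to receive the final result. One last sketch, with the same assumptions as the earlier examples:

    import torch
    from torch.distributed import ReduceOp
    from colossalai.communication.collective import reduce
    from colossalai.context import ParallelMode
    from colossalai.core import global_context as gpc

    local_loss = torch.tensor([1.0], device="cuda")

    # Sum local_loss across the group; only rank 0 (dst) is guaranteed to hold the result.
    total = reduce(local_loss, dst=0, parallel_mode=ParallelMode.GLOBAL, op=ReduceOp.SUM)

    if gpc.get_global_rank() == 0:
        print(total.item())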