OpenDAS / FastFold · Commits

Commit 90019096 (unverified)
Authored Mar 07, 2022 by shenggan; committed by GitHub on Mar 07, 2022

Merge pull request #4 from hpcaitech/fix_gather

fix minor bug in gather

Parents: e96b76b0, 77642096
Showing 1 changed file with 3 additions and 3 deletions.
fastfold/distributed/comm.py (+3, -3)
@@ -4,7 +4,7 @@ import torch
 import torch.distributed as dist
 from torch import Tensor
 
-from .core import (get_tensor_model_parallel_group, get_tensor_model_parallel_src_rank,
+from .core import (get_tensor_model_parallel_group, get_tensor_model_parallel_rank,
                    get_tensor_model_parallel_world_size)
 from .core import ensure_divisibility
@@ -33,7 +33,7 @@ def _split(tensor: Tensor, dim: int = -1) -> Tensor:
     split_size = divide(tensor.shape[dim], get_tensor_model_parallel_world_size())
     tensor_list = torch.split(tensor, split_size, dim=dim)
-    output = tensor_list[get_tensor_model_parallel_src_rank()].contiguous()
+    output = tensor_list[get_tensor_model_parallel_rank()].contiguous()
     return output
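
Why this hunk matters: _split hands each tensor-parallel rank its own slice of the input, so the chunk list must be indexed by the caller's own rank (get_tensor_model_parallel_rank). Indexing by the source rank, as the old line did, returned the same chunk on every rank. A minimal standalone sketch of the corrected behavior (not FastFold's code; split_for_rank and the even-divisibility assumption are illustrative):

import torch

# Illustrative stand-in for _split: each rank keeps the chunk at its own
# index, which is what the change from ..._src_rank() to ..._rank() restores.
def split_for_rank(tensor: torch.Tensor, rank: int, world_size: int,
                   dim: int = -1) -> torch.Tensor:
    split_size = tensor.shape[dim] // world_size  # assumes even divisibility
    chunks = torch.split(tensor, split_size, dim=dim)
    return chunks[rank].contiguous()  # rank-specific slice

t = torch.arange(8.0)
# With rank-based indexing, each rank sees a distinct slice:
assert torch.equal(split_for_rank(t, rank=0, world_size=4), torch.tensor([0., 1.]))
assert torch.equal(split_for_rank(t, rank=3, world_size=4), torch.tensor([6., 7.]))
# With a fixed source rank, all four ranks would have received chunks[src_rank].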
@@ -49,7 +49,7 @@ def _gather(tensor: Tensor, dim: int = -1) -> Tensor:
         tensor_list = output.chunk(get_tensor_model_parallel_world_size(), dim=1)
         dist.all_gather(list(tensor_list), tensor, group=get_tensor_model_parallel_group(), async_op=False)
     else:
-        tensor_list = [torch.ones_like(tensor) for _ in range(get_tensor_model_parallel_world_size())]
+        tensor_list = [torch.empty_like(tensor) for _ in range(get_tensor_model_parallel_world_size())]
         dist.all_gather(tensor_list, tensor, group=get_tensor_model_parallel_group(), async_op=False)
         output = torch.cat(tensor_list, dim=dim)
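
Why this hunk matters: dist.all_gather overwrites every element of every buffer in tensor_list, so the values used to pre-allocate those buffers never survive the collective. Switching from torch.ones_like to torch.empty_like therefore leaves the output unchanged while skipping a redundant fill of each buffer; it reads as a cleanup alongside the rank fix rather than a behavioral change. A minimal, self-contained sketch (not FastFold's code; the single-process gloo group, address, and port are illustrative assumptions):

import os
import torch
import torch.distributed as dist

def main() -> None:
    # Single-process "world" so the sketch runs anywhere; real tensor
    # parallelism would launch one process per device.
    os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
    os.environ.setdefault("MASTER_PORT", "29500")
    dist.init_process_group("gloo", rank=0, world_size=1)

    tensor = torch.randn(4)
    # empty_like is enough: all_gather writes every element of every buffer,
    # so initializing them (e.g. with ones_like) only costs an extra fill.
    tensor_list = [torch.empty_like(tensor) for _ in range(dist.get_world_size())]
    dist.all_gather(tensor_list, tensor, async_op=False)

    output = torch.cat(tensor_list, dim=-1)
    assert torch.equal(output, tensor)  # world_size == 1: gather is the identity
    dist.destroy_process_group()

if __name__ == "__main__":
    main()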