OpenDAS / ColossalAI

Commit c336cd30
[NFC] polish colossalai/communication/utils.py code style (#656)
Authored Apr 02, 2022 by FredHuang99; committed by binmakeswell on Apr 06, 2022
Parent: 5ab9a712
Showing 1 changed file with 2 additions and 6 deletions

colossalai/communication/utils.py (+2, -6)
@@ -77,9 +77,7 @@ def split_tensor_into_1d_equal_chunks(tensor, new_buffer=False):
    start_index = partition_size * gpc.get_local_rank(ParallelMode.PARALLEL_1D)
    end_index = start_index + partition_size
    if new_buffer:
        data = torch.empty(partition_size, dtype=tensor.dtype,
                           device=torch.cuda.current_device(), requires_grad=False)
        data.copy_(tensor.view(-1)[start_index:end_index])
    else:
        data = tensor.view(-1)[start_index:end_index]
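For readers skimming the diff: split_tensor_into_1d_equal_chunks hands each 1D-parallel rank a contiguous slice of the flattened tensor, either as a view or as a freshly allocated copy. Below is a minimal CPU-only sketch of that logic, with the ColossalAI globals (gpc, ParallelMode.PARALLEL_1D) replaced by explicit rank and world_size arguments; the function name and signature are illustrative, not ColossalAI's API.

import torch

def split_1d_equal_chunks(tensor, rank, world_size, new_buffer=False):
    # Illustrative stand-in for ColossalAI's helper: each rank gets an equal,
    # contiguous 1/world_size slice of the flattened tensor.
    partition_size = torch.numel(tensor) // world_size
    start_index = partition_size * rank  # gpc.get_local_rank(...) in the real code
    end_index = start_index + partition_size
    if new_buffer:
        # Copy into a fresh buffer so the chunk owns its own memory.
        data = torch.empty(partition_size, dtype=tensor.dtype, requires_grad=False)
        data.copy_(tensor.view(-1)[start_index:end_index])
    else:
        # Cheap path: return a view into the original storage, no copy.
        data = tensor.view(-1)[start_index:end_index]
    return data

# Example: 8 elements over 4 ranks -> 2 elements per rank.
print(split_1d_equal_chunks(torch.arange(8.), rank=1, world_size=4))  # tensor([2., 3.])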
@@ -97,9 +95,7 @@ def gather_split_1d_tensor(tensor):
    world_size = gpc.get_world_size(ParallelMode.PARALLEL_1D)
    numel = torch.numel(tensor)
    numel_gathered = world_size * numel
    gathered = torch.empty(numel_gathered, dtype=tensor.dtype,
                           device=torch.cuda.current_device(), requires_grad=False)
    chunks = [gathered[i * numel:(i + 1) * numel] for i in range(world_size)]
    dist.all_gather(chunks, tensor, group=gpc.get_group(ParallelMode.PARALLEL_1D))
    return gathered
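The gather side shows the standard trick of preallocating one contiguous output buffer and passing dist.all_gather a list of views into it, so each rank's chunk lands in place and no extra concatenation is needed. A runnable single-process sketch under the same assumptions (explicit world_size instead of gpc, gloo backend on CPU instead of NCCL; the helper name and the rendezvous address are hypothetical):

import torch
import torch.distributed as dist

def gather_split_1d(tensor, world_size, group=None):
    # Preallocate the full output, then give all_gather views into it so each
    # rank's chunk is written directly in place (mirrors the diff above).
    numel = torch.numel(tensor)
    gathered = torch.empty(world_size * numel, dtype=tensor.dtype, requires_grad=False)
    chunks = [gathered[i * numel:(i + 1) * numel] for i in range(world_size)]
    dist.all_gather(chunks, tensor, group=group)
    return gathered

if __name__ == "__main__":
    # Single-process demo; real deployments launch one process per rank and
    # would use the NCCL backend plus the PARALLEL_1D group as in the diff.
    dist.init_process_group("gloo", init_method="tcp://127.0.0.1:29500",
                            rank=0, world_size=1)
    part = torch.arange(4.)
    print(gather_split_1d(part, dist.get_world_size()))  # tensor([0., 1., 2., 3.])
    dist.destroy_process_group()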