OpenDAS / ColossalAI · Commits

Commit 06dccdde
Authored Sep 08, 2022 by Fazzie-Maqianli; committed by Frank Lee, Sep 08, 2022

[NFC] polish colossalai/zero/sharded_model/reduce_scatter.py code style (#1554)

Parent: 2ac46f7b
Showing 1 changed file with 13 additions and 13 deletions (+13, −13).
colossalai/zero/sharded_model/reduce_scatter.py @ 06dccdde
```python
# @@ -20,6 +20,7 @@ else:
class Bucket:

    def __init__(self, shard_size: int, dtype: torch.dtype, device: torch.device, group: ProcessGroup):
        self.buffer = torch.zeros((group.size(), shard_size), dtype=dtype, device=device)
        self.group = group
```
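The `Bucket` above keeps one flat buffer with one row per rank in the process group. A minimal standalone sketch of that layout, with hypothetical sizes and no process group initialized:

```python
# Minimal sketch of the bucket layout (hypothetical sizes, no dist init):
# row r of the buffer holds the chunk this rank contributes toward rank r,
# so the buffer is (world_size, shard_size) and each rank ultimately
# receives one row's worth of reduced data.
import torch

world_size, shard_size = 4, 1024
buffer = torch.zeros((world_size, shard_size), dtype=torch.float16)
output_shard = torch.zeros_like(buffer[0])    # what this rank will receive

print(buffer.shape)          # torch.Size([4, 1024])
print(output_shard.shape)    # torch.Size([1024])
```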
```python
# @@ -34,18 +35,18 @@ class Bucket:
            return
        # reduce-scatter bucket
        if hasattr(dist, "_reduce_scatter_base") and enable_nccl_base_collectives:
            dist._reduce_scatter_base(self.output_shard[:self.offset],
                                      self.buffer[:, :self.offset].contiguous(),
                                      group=self.group)
        else:
            dist.reduce_scatter(self.output_shard[:self.offset],
                                list(self.buffer[:, :self.offset].unbind(0)),
                                group=self.group)
        # execute post-reduction callbacks
        for callback_fn in self.callbacks:
            callback_fn()
        # reuse input bucket but allocate a fresh output shard
        self.buffer[:, :self.offset].zero_()
        self.offset = 0
        self.callbacks.clear()
        self.output_shard = torch.zeros_like(self.buffer[0])
```
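The flush above prefers the fused `_reduce_scatter_base` (a private torch.distributed API, gated behind `enable_nccl_base_collectives`) and falls back to `reduce_scatter` over the list of per-rank rows. A single-process sketch of the arithmetic both paths compute, assuming the default sum reduction:

```python
# Single-process sketch of reduce-scatter semantics (no process group):
# with the default sum op, rank r ends up with the element-wise sum of
# row r of every rank's bucket buffer.
import torch

world_size, shard_size = 4, 8
# one (world_size, shard_size) bucket per rank:
buckets = [torch.randn(world_size, shard_size) for _ in range(world_size)]

# what dist.reduce_scatter would deliver to rank 0:
rank0_shard = sum(bucket[0] for bucket in buckets)
print(rank0_shard.shape)    # torch.Size([8])
```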
```python
# @@ -73,12 +74,12 @@ class Bucket:
        tensor_size = tensor_list[0].numel()
        stacked_input = torch.stack(tensor_list).view(self.group.size(), tensor_size)
        offset = self.offset
        self.buffer[:, offset:offset + tensor_size].copy_(stacked_input)
        self.offset += tensor_size

        # callback will be given the reduced result
        if callback_fn is not None:
            result_view = self.output_shard[offset:offset + tensor_size].view_as(tensor_list[0])
            self.callbacks.append(functools.partial(callback_fn, result_view))
```
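This hunk hands each caller a view into `output_shard` and defers the callback with `functools.partial` until the reduction lands. A minimal sketch of that pattern, with a hypothetical consumer standing in for `callback_fn`:

```python
# Minimal sketch of the deferred-callback pattern (consumer is hypothetical):
import functools

import torch

callbacks = []
output_shard = torch.zeros(16)

def on_reduced(view: torch.Tensor) -> None:    # hypothetical callback_fn
    print("reduced chunk:", view.shape)

# bind a view of the shard now; flush() fills it in later
result_view = output_shard[0:8].view_as(torch.empty(2, 4))
callbacks.append(functools.partial(on_reduced, result_view))

# ...after the reduce-scatter completes:
for cb in callbacks:
    cb()    # prints: reduced chunk: torch.Size([2, 4])
```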
```python
# @@ -141,9 +142,8 @@ class ReduceScatterBucketer:
        """
        world_size = group.size()

        assert (len(input_list) == world_size), \
            f"reduce_scatter received {len(input_list)} inputs, expected group.size() ({world_size})"

        first_input = input_list[0]
        first_input_size = first_input.numel()
```
```python
# @@ -183,7 +183,7 @@ class ReduceScatterBucketer:
    @functools.lru_cache()
    def _get_shard_size(self, element_size: int, num_shards: int) -> int:
        if self.bucket_size_mb <= 0:    # Values <= 0 disable bucketing.
            return 0
        MB = 1024 * 1024
        bucket_size = self.bucket_size_mb * MB / element_size
```
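The diff context cuts off before `_get_shard_size` returns. A self-contained sketch of the computation, assuming the bucket capacity is split evenly across `num_shards` ranks (that final division is an assumption, not shown in this hunk):

```python
# Standalone sketch of the shard-size computation; the final `// num_shards`
# step is an assumption, since the diff context ends at the bucket_size line.
def get_shard_size(bucket_size_mb: int, element_size: int, num_shards: int) -> int:
    if bucket_size_mb <= 0:    # values <= 0 disable bucketing
        return 0
    MB = 1024 * 1024
    bucket_size = bucket_size_mb * MB / element_size    # capacity in elements, not bytes
    return int(bucket_size // num_shards)               # assumed even split per rank

# e.g. a 25 MB bucket of fp32 (4-byte) elements across 8 ranks:
print(get_shard_size(25, 4, 8))    # 819200 elements per shard
```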