OpenDAS / fairscale · Commits

Commit 4dc605c9 (unverified)
Authored Feb 04, 2021 by Benjamin Lefaudeux; committed by GitHub on Feb 04, 2021
Parent: 42e44149

[perf] ShardedDDP - small memory use reduction - minor speedup (#366)

* minor
* minor

Showing 1 changed file with 40 additions and 29 deletions.

fairscale/nn/data_parallel/sharded_ddp.py (+40, -29)
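In short, the diff below stops re-walking the nested optimizer → device → bucket dictionary in the hot paths and instead builds a flat bucket list once during setup. A minimal standalone sketch of that pattern, using illustrative names rather than fairscale's actual classes:

```python
from typing import Dict, List


class Bucket:
    """Toy stand-in for a gradient bucket; only tracks whether it was sent."""

    def __init__(self) -> None:
        self.sent = False

    def reset(self) -> None:
        self.sent = False


# Nested structure keyed by optimizer name and device string, for illustration.
buckets: Dict[str, Dict[str, List[Bucket]]] = {
    "optim0": {"cuda:0": [Bucket(), Bucket()], "cuda:1": [Bucket()]},
}

# Before: three nested loops executed on every pass.
for per_device in buckets.values():
    for per_rank in per_device.values():
        for bucket in per_rank:
            bucket.reset()

# After: flatten once at setup time, then iterate the flat list in the hot path.
bucket_list: List[Bucket] = [b for per_device in buckets.values() for per_rank in per_device.values() for b in per_rank]
for bucket in bucket_list:
    bucket.reset()
```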
@@ -142,6 +142,7 @@ class ShardedDataParallel(nn.Module):
         self.buckets: Dict[OSS, Dict[torch.device, List[Bucket]]] = {o: {} for o in self.sharded_optimizers}
         self._should_bucket_grad: List[bool] = []
+        self._bucket_list: Optional[List[Bucket]] = None
         self._setup_bucket_strategy()
 
         # - setup backward hooks which will be called by Torch's autograd in due time
@@ -155,6 +156,8 @@ class ShardedDataParallel(nn.Module):
         if sync_models_at_startup:
             self._sync_params_and_buffers()
 
+        self._clear_counters()
+
     def forward(self, *inputs: Any, **kwargs: Any) -> Any:
         """
         Module forward pass, handles any DDP-specific work in the background. Primes the
@@ -256,15 +259,16 @@ class ShardedDataParallel(nn.Module):
         self._grad_to_be_reduced = [True for _ in self._grad_to_be_reduced]
         self._reduced_grads = {o: 0 for o in self.sharded_optimizers}
 
-        for optimizer in self.buckets.keys():
-            for device in self.buckets[optimizer].keys():
-                for bucket in self.buckets[optimizer][device]:
-                    assert bucket.sent, (
-                        "A bucket failed to be sent, probably unused parameters."
-                        + "Either remove the unused parameter or de-activate ShardedDDP buckets -set reduce_buffer_size to 0-"
-                    )
-                    bucket.reset()
+        if self.use_buckets:
+            assert self._bucket_list is not None
+
+            for bucket in self._bucket_list:
+                assert bucket.sent, (
+                    "A bucket failed to be sent, probably unused parameters."
+                    + "Either remove the unused parameter or de-activate ShardedDDP buckets -set reduce_buffer_size to 0-"
+                )
+                bucket.reset()
 
     def _find_rank(self, param: Parameter) -> Tuple[OSS, int]:
         """ Look up where this parameter belongs to """
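The assertion above trips when a bucket was never flushed, typically because some parameters received no gradient. As the message suggests, bucketing can be turned off entirely; a hedged usage sketch (the constructor arguments are assumed from fairscale's public API of this period, and a torch.distributed process group must already be initialized):

```python
import torch
from fairscale.nn.data_parallel import ShardedDataParallel as ShardedDDP
from fairscale.optim.oss import OSS

# Assumes dist.init_process_group(...) has already been called.
model = torch.nn.Linear(16, 16)  # stand-in module; some parameters may go unused in a real model
optimizer = OSS(params=model.parameters(), optim=torch.optim.SGD, lr=0.1)

# reduce_buffer_size=0 de-activates ShardedDDP buckets, avoiding the "bucket failed to be sent"
# assertion when the model has unused parameters.
model = ShardedDDP(model, optimizer, reduce_buffer_size=0)
```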
@@ -374,19 +378,16 @@ class ShardedDataParallel(nn.Module):
         def bucket_flush(*unused: Any) -> None:
             handle = None
 
-            for bucket_optim in self.buckets.values():
-                for bucket_rank in bucket_optim.values():
-                    for bucket in bucket_rank:
-                        if not bucket.sent:
-                            # Reduce the bucket. Some parameters went unused and this bucket was not flushed
-                            bucket.buffer.mul_(self.world_size_scaling)
-                            bucket.sent = True
-                            handle = dist.reduce(
-                                tensor=bucket.buffer,
-                                dst=bucket.destination,
-                                group=self.process_group,
-                                async_op=True,
-                            )
+            assert self._bucket_list is not None
+
+            for bucket in self._bucket_list:
+                if not bucket.sent:
+                    # Reduce the bucket. Some parameters went unused and this bucket was not flushed
+                    bucket.buffer.mul_(self.world_size_scaling)
+                    bucket.sent = True
+                    handle = dist.reduce(
+                        tensor=bucket.buffer, dst=bucket.destination, group=self.process_group, async_op=True,
+                    )
 
             # Only wait on the last handle
             if handle:
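The flush path issues every reduce with async_op=True and waits only on the last returned work handle before moving on (the "# Only wait on the last handle" comment above). A self-contained sketch of that pattern, using a single-rank gloo group so it runs standalone:

```python
import os

import torch
import torch.distributed as dist

# Minimal single-rank process group so the example is runnable on its own.
os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
os.environ.setdefault("MASTER_PORT", "29501")
dist.init_process_group(backend="gloo", rank=0, world_size=1)

buffers = [torch.ones(4) for _ in range(3)]
handle = None

# Launch all reduces asynchronously; each call returns a work handle immediately.
for buf in buffers:
    handle = dist.reduce(tensor=buf, dst=0, async_op=True)

# Collectives enqueued on the same process group complete in order, so waiting on the
# last handle is enough to know the earlier reduces have finished as well.
if handle is not None:
    handle.wait()

dist.destroy_process_group()
```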
@@ -430,19 +431,19 @@ class ShardedDataParallel(nn.Module):
         if not self.use_buckets:
             return
 
-        # - Allocate one buffer per rank and per device to group the small parameters
-        for sharded_optimizer in self.sharded_optimizers:
-            for device, per_device in sharded_optimizer.per_device_params.items():
-                self.buckets[sharded_optimizer][device] = [
-                    Bucket(buffer=torch.zeros(self.buffer_max_size, dtype=per_device[0][0].dtype, device=device))
-                    for _ in per_device
-                ]
-
         # Devise the bucketing strategy
         for sharded_optimizer in self.sharded_optimizers:
             for device, per_rank_params in sharded_optimizer.per_device_params.items():
+                self.buckets[sharded_optimizer][device] = []
+
                 for dst_rank, params in enumerate(per_rank_params):
                     offset = 0
 
+                    self.buckets[sharded_optimizer][device].append(
+                        Bucket(
+                            buffer=torch.zeros(self.buffer_max_size, dtype=per_rank_params[0][0].dtype, device=device)
+                        )
+                    )
                     bucket = self.buckets[sharded_optimizer][device][dst_rank]
                     bucket.destination = dst_rank
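For context, a small sketch of what these per-rank buckets buy (illustrative only, not fairscale's internals): many small gradients are copied into one pre-allocated flat buffer so that a single collective replaces many tiny ones.

```python
import torch

# A few small gradient tensors that would otherwise each need their own reduce call.
grads = [torch.randn(3), torch.randn(5), torch.randn(2)]

# Pre-allocate one flat bucket buffer large enough for all of them.
buffer = torch.zeros(sum(g.numel() for g in grads), dtype=grads[0].dtype)

# Copy each gradient into its slice of the buffer, recording the offsets.
offsets = []
offset = 0
for g in grads:
    buffer[offset : offset + g.numel()].copy_(g.view(-1))
    offsets.append(offset)
    offset += g.numel()

# A single collective on `buffer` now stands in for several small ones, e.g.
#   dist.reduce(tensor=buffer, dst=destination, group=process_group, async_op=True)
# and each reduced gradient can be read back through its recorded offset:
first = buffer[offsets[0] : offsets[0] + grads[0].numel()].view_as(grads[0])
```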
@@ -473,3 +474,13 @@ class ShardedDataParallel(nn.Module):
                     bucket.buffer.resize_(offset)
 
                     if bucket.max_params_checked_in > 0:
                         self._reduced_grads_max[sharded_optimizer] += 1  # one reduce call per bucket
+
+        self._bucket_list = list(
+            chain(
+                *[
+                    self.buckets[sharded_optimizer][device]
+                    for sharded_optimizer in self.sharded_optimizers
+                    for device in sharded_optimizer.per_device_params.keys()
+                ]
+            )
+        )
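The flat list in the last hunk is built with itertools.chain over the per-optimizer, per-device bucket lists. A tiny example of that construction, together with the equivalent chain.from_iterable spelling that skips the intermediate list:

```python
from itertools import chain

buckets = {"optim0": {"cuda:0": [1, 2], "cuda:1": [3]}, "optim1": {"cuda:0": [4]}}

# Same shape as in the diff: unpack a list of per-device lists into chain().
flat = list(chain(*[per_device for per_optim in buckets.values() for per_device in per_optim.values()]))

# Equivalent, without materializing the intermediate list first.
flat_alt = list(chain.from_iterable(per_device for per_optim in buckets.values() for per_device in per_optim.values()))

assert flat == flat_alt == [1, 2, 3, 4]
```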