OpenDAS / fairscale

Commit 4dc605c9 (unverified)
Authored Feb 04, 2021 by Benjamin Lefaudeux; committed by GitHub on Feb 04, 2021

[perf] ShardedDDP - small memory use reduction - minor speedup (#366)

* minor
* minor

Parent: 42e44149
Changes: 1 changed file with 40 additions and 29 deletions

fairscale/nn/data_parallel/sharded_ddp.py @ 4dc605c9 (+40, -29)
@@ -142,6 +142,7 @@ class ShardedDataParallel(nn.Module):
         self.buckets: Dict[OSS, Dict[torch.device, List[Bucket]]] = {o: {} for o in self.sharded_optimizers}
         self._should_bucket_grad: List[bool] = []
+        self._bucket_list: Optional[List[Bucket]] = None
         self._setup_bucket_strategy()

         # - setup backward hooks which will be called by Torch's autograd in due time
@@ -155,6 +156,8 @@ class ShardedDataParallel(nn.Module):
         if sync_models_at_startup:
             self._sync_params_and_buffers()

+        self._clear_counters()
+
     def forward(self, *inputs: Any, **kwargs: Any) -> Any:
         """
         Module forward pass, handles any DDP-specific work in the background. Primes the
@@ -256,9 +259,10 @@ class ShardedDataParallel(nn.Module):
         self._grad_to_be_reduced = [True for _ in self._grad_to_be_reduced]
         self._reduced_grads = {o: 0 for o in self.sharded_optimizers}

-        for optimizer in self.buckets.keys():
-            for device in self.buckets[optimizer].keys():
-                for bucket in self.buckets[optimizer][device]:
+        if self.use_buckets:
+            assert self._bucket_list is not None
+
+            for bucket in self._bucket_list:
                 assert bucket.sent, (
                     "A bucket failed to be sent, probably unused parameters."
                     + "Either remove the unused parameter or de-activate ShardedDDP buckets -set reduce_buffer_size to 0-"
@@ -374,18 +378,15 @@ class ShardedDataParallel(nn.Module):
         def bucket_flush(*unused: Any) -> None:
             handle = None

-            for bucket_optim in self.buckets.values():
-                for bucket_rank in bucket_optim.values():
-                    for bucket in bucket_rank:
+            assert self._bucket_list is not None
+
+            for bucket in self._bucket_list:
                 if not bucket.sent:
                     # Reduce the bucket. Some parameters went unused and this bucket was not flushed
                     bucket.buffer.mul_(self.world_size_scaling)
                     bucket.sent = True
                     handle = dist.reduce(
                         tensor=bucket.buffer,
                         dst=bucket.destination, group=self.process_group, async_op=True,
                     )

             # Only wait on the last handle
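The flush path above launches one asynchronous reduce per unsent bucket and then waits only on the last handle. A rough sketch of that pattern as a standalone function; flush_buckets is a hypothetical name, the bucket objects are assumed to carry .buffer/.destination/.sent as in the diff, and an initialized torch.distributed process group is assumed:

from typing import Any, List

import torch.distributed as dist

def flush_buckets(bucket_list: List[Any], world_size_scaling: float, process_group: Any = None) -> None:
    # Each bucket is assumed to expose .buffer (a tensor), .destination (a rank) and .sent (a bool),
    # mirroring the fields used in the diff above.
    handle = None

    for bucket in bucket_list:
        if not bucket.sent:
            # Pre-scale by 1 / world_size, then launch an asynchronous reduce towards the owning rank.
            bucket.buffer.mul_(world_size_scaling)
            bucket.sent = True
            handle = dist.reduce(
                tensor=bucket.buffer, dst=bucket.destination, group=process_group, async_op=True
            )

    # As in the diff's "only wait on the last handle": reduces queued on the same group are
    # processed in order, so waiting on the last handle covers the earlier ones as well.
    if handle is not None:
        handle.wait()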
@@ -430,19 +431,19 @@ class ShardedDataParallel(nn.Module):
         if not self.use_buckets:
             return

-        # - Allocate one buffer per rank and per device to group the small parameters
-        for sharded_optimizer in self.sharded_optimizers:
-            for device, per_device in sharded_optimizer.per_device_params.items():
-                self.buckets[sharded_optimizer][device] = [
-                    Bucket(buffer=torch.zeros(self.buffer_max_size, dtype=per_device[0][0].dtype, device=device))
-                    for _ in per_device
-                ]
-
         # Devise the bucketing strategy
         for sharded_optimizer in self.sharded_optimizers:
             for device, per_rank_params in sharded_optimizer.per_device_params.items():
+                self.buckets[sharded_optimizer][device] = []
+
                 for dst_rank, params in enumerate(per_rank_params):
                     offset = 0

+                    self.buckets[sharded_optimizer][device].append(
+                        Bucket(buffer=torch.zeros(self.buffer_max_size, dtype=per_rank_params[0][0].dtype, device=device))
+                    )
                     bucket = self.buckets[sharded_optimizer][device][dst_rank]
                     bucket.destination = dst_rank
@@ -473,3 +474,13 @@ class ShardedDataParallel(nn.Module):
                     bucket.buffer.resize_(offset)
                     if bucket.max_params_checked_in > 0:
                         self._reduced_grads_max[sharded_optimizer] += 1  # one reduce call per bucket
+
+        self._bucket_list = list(
+            chain(
+                *[
+                    self.buckets[sharded_optimizer][device]
+                    for sharded_optimizer in self.sharded_optimizers
+                    for device in sharded_optimizer.per_device_params.keys()
+                ]
+            )
+        )
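The context lines above show each bucket buffer being allocated at self.buffer_max_size and later trimmed with resize_(offset) once its parameters have checked in. A small illustration of that allocate-then-trim idea, with invented sizes; this is a sketch of the pattern, not fairscale's implementation:

# Hedged illustration of the allocate-then-trim idea, not fairscale's code; sizes are invented.
import torch

buffer_max_size = 2 ** 19              # per-bucket cap, in elements
param_numels = [4096, 1024, 256]       # hypothetical flattened gradient sizes owned by one rank

buffer = torch.zeros(buffer_max_size, dtype=torch.float32)

offset = 0
for numel in param_numels:
    # In ShardedDDP, each small gradient would be mapped to buffer[offset : offset + numel].
    offset += numel

buffer.resize_(offset)                 # keep only the slots that were actually checked in
print(buffer.shape)                    # torch.Size([5376]) rather than torch.Size([524288])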