OpenDAS / fairscale · Commits

Commit 0cbf3bab (unverified)
Authored Mar 09, 2021 by Myle Ott; committed via GitHub on Mar 09, 2021
[perf] Further improve performance for FSDP.no_sync (#502)
Parent: aa9129a3
Changes: 2 files with 61 additions and 14 deletions (+61 −14)

- fairscale/nn/data_parallel/fully_sharded_data_parallel.py (+4 −4)
- tests/nn/data_parallel/test_fsdp_no_sync.py (+57 −10)
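For context, ``no_sync()`` is the FSDP context manager whose overhead this commit reduces: it lets callers accumulate gradients locally over several micro-batches and only synchronize on the last one. A minimal usage sketch follows (illustrative only; the module, optimizer, and batch names are assumptions rather than code from this commit, and an initialized torch.distributed process group plus a CUDA device are presupposed).

    # Gradient accumulation with FSDP.no_sync -- illustrative sketch, not code from this commit.
    # Assumes torch.distributed is already initialized and CUDA is available.
    import torch
    from fairscale.nn.data_parallel import FullyShardedDataParallel as FSDP

    model = FSDP(torch.nn.Linear(8, 8).cuda(), reshard_after_forward=False)
    optim = torch.optim.SGD(model.parameters(), lr=0.01)
    batches = [torch.randn(4, 8, device="cuda") for _ in range(3)]

    with model.no_sync():  # skip gradient reduce-scatter for the first N-1 micro-batches
        for batch in batches[:-1]:
            model(batch).sum().backward()  # gradients accumulate locally

    model(batches[-1]).sum().backward()  # final micro-batch: gradients are synchronized
    optim.step()
    optim.zero_grad()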
fairscale/nn/data_parallel/fully_sharded_data_parallel.py @ 0cbf3bab
@@ -901,11 +901,11 @@ class FullyShardedDataParallel(nn.Module):
         if param.grad.requires_grad:
             raise RuntimeError("FullyShardedDataParallel only works with gradients that don't require grad")
 
-        if not self._is_root or self._require_backward_grad_sync:
+        if self._require_backward_grad_sync or self.reshard_after_forward:
             # Free full params. As a special case, we don't free the full params
-            # on the root instance when in a ``no_sync`` context (as indicated
-            # by ``self._require_backward_grad_sync``), since we will need the
-            # params again immediately for the next forward.
+            # when in a ``no_sync`` context (as inversely indicated by
+            # ``self._require_backward_grad_sync``), since the params will not
+            # get updated before the next forward.
             self._free_full_params([param])
 
         if self.mixed_precision:
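In short, the post-backward hook now frees the full parameters only when a gradient sync is required or when ``reshard_after_forward`` is set, so inside ``no_sync()`` with ``reshard_after_forward=False`` the gathered parameters survive until the next forward. A stand-alone paraphrase of the new condition (for illustration only; this helper is not part of the module):

    def should_free_full_params(require_backward_grad_sync: bool, reshard_after_forward: bool) -> bool:
        # Paraphrase of the updated guard around ``_free_full_params`` above.
        # Inside ``no_sync`` (require_backward_grad_sync == False) with
        # reshard_after_forward == False, full params stay materialized, so the
        # next forward needs no extra all-gather.
        return require_backward_grad_sync or reshard_after_forward

    assert should_free_full_params(True, False)       # regular sync step: free and reshard
    assert not should_free_full_params(False, False)  # no_sync without resharding: keep params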
tests/nn/data_parallel/test_fsdp_no_sync.py @ 0cbf3bab
@@ -4,15 +4,17 @@
 # LICENSE file in the root directory of this source tree.
 
 import functools
+import itertools
 import unittest
 from unittest.mock import patch
 
 from parameterized import parameterized
 import torch
 
 from fairscale.nn.data_parallel import FullyShardedDataParallel
 from fairscale.utils.testing import DummyProcessGroup, objects_are_equal
-from .test_fsdp import DistributedTest, NestedWrappedModule, spawn_and_init
+from .test_fsdp import DistributedTest, NestedWrappedModule, rename_test, spawn_and_init
 
 
 class TestNoSync(DistributedTest):
@@ -94,21 +96,60 @@ class TestNoSync(DistributedTest):
         assert objects_are_equal(ref_grads, accumulated_grads, raise_exception=True)
 
 
+keys = ["reshard_after_forward", "mixed_precision"]
+COMM_CONFIG_OPTIONS = [[dict(zip(keys, config))] for config in itertools.product([True, False], repeat=len(keys))]
+
+
 class TestNoSyncCommunication(DistributedTest):
-    def test_communication(self):
-        config = {"mixed_precision": True}
+    @parameterized.expand(COMM_CONFIG_OPTIONS, name_func=rename_test)
+    def test_communication(self, config):
         fn = functools.partial(self._test_communication, config=config)
         spawn_and_init(fn)
 
+    @parameterized.expand(COMM_CONFIG_OPTIONS, name_func=rename_test)
+    def test_communication_nested(self, config):
+        fn = functools.partial(self._test_communication, config=config, nested_model=True)
+        spawn_and_init(fn)
+
     @classmethod
-    def _test_communication(self, rank, group, config):
+    def _test_communication(self, rank, group, config, nested_model=False):
         if group.size() == 1:
             return
-        model = self.get_wrapped_model(group, config=config)
+
+        # Turn off bucketing to accurately count number of reduce_scatters.
+        config["bucket_cap_mb"] = 0
+
+        if nested_model:
+            model = NestedWrappedModule(group, config)
+            model = FullyShardedDataParallel(model, group, **config).cuda()
+        else:
+            model = self.get_wrapped_model(group, config=config)
+
+        num_fsdp = 0
+        for child in model.modules():  # includes self
+            if isinstance(child, FullyShardedDataParallel) and len(child.params) > 0:
+                num_fsdp += 1
+
+        if config.get("reshard_after_forward", True):
+            # inside no_sync:
+            # num_fsdp all-gathers in the forward
+            # num_fsdp-1 all-gathers in the backward (except root)
+            # outside no_sync:
+            # num_fsdp-1 all-gathers in the forward (except root)
+            # num_fsdp-1 all-gathers in the backward (except root)
+            expected_all_gather1 = 2 * num_fsdp - 1
+            expected_all_gather2 = expected_all_gather1 + (2 * num_fsdp - 2)
+        else:
+            # inside no_sync:
+            # num_fsdp all-gathers in the forward
+            # outside no_sync:
+            # none
+            expected_all_gather1 = num_fsdp
+            expected_all_gather2 = num_fsdp
+
+        expected_reduce_scatter = num_fsdp
+
         batch = model.module.get_input(torch.device("cuda"))
 
         with patch("torch.distributed.all_gather") as mock_all_gather:
             with patch("torch.distributed.reduce_scatter") as mock_reduce_scatter:
                 with model.no_sync():
@@ -116,15 +157,21 @@ class TestNoSyncCommunication(DistributedTest):
                     loss = model.module.get_loss(batch, output)
                     loss.backward()
 
-                assert mock_all_gather.call_count == 1
-                assert mock_reduce_scatter.call_count == 0
+                assert (
+                    mock_all_gather.call_count == expected_all_gather1
+                ), f"{mock_all_gather.call_count} != {expected_all_gather1}"
+                assert mock_reduce_scatter.call_count == 0, f"{mock_reduce_scatter.call_count} != 0"
 
                 output = model(*batch)
                 loss = model.module.get_loss(batch, output)
                 loss.backward()
 
-                assert mock_all_gather.call_count == 1
-                assert mock_reduce_scatter.call_count == 1
+                assert (
+                    mock_all_gather.call_count == expected_all_gather2
+                ), f"{mock_all_gather.call_count} != {expected_all_gather2}"
+                assert (
+                    mock_reduce_scatter.call_count == expected_reduce_scatter
+                ), f"{mock_reduce_scatter.call_count} != {expected_reduce_scatter}"
 
 
 if __name__ == "__main__":
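To make the counting in ``_test_communication`` concrete, here is a small worked example of the expected call counts (``num_fsdp = 3`` is an assumed value for illustration, e.g. a root wrapper plus two wrapped children; it is not taken from the test).

    # Reproduces the arithmetic from the test's comments for an assumed num_fsdp = 3.
    num_fsdp = 3

    # reshard_after_forward=True:
    # inside no_sync:  num_fsdp forward all-gathers + (num_fsdp - 1) backward all-gathers
    # outside no_sync: (num_fsdp - 1) forward + (num_fsdp - 1) backward all-gathers
    expected_all_gather1 = 2 * num_fsdp - 1                           # 5
    expected_all_gather2 = expected_all_gather1 + (2 * num_fsdp - 2)  # 5 + 4 = 9

    # reshard_after_forward=False: only the first forward all-gathers; params stay materialized.
    expected_all_gather1_no_reshard = num_fsdp  # 3
    expected_all_gather2_no_reshard = num_fsdp  # still 3 after the second step

    print(expected_all_gather1, expected_all_gather2)  # 5 9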