OpenDAS / fairscale · Commits · 0cbf3bab

Unverified commit 0cbf3bab, authored Mar 09, 2021 by Myle Ott, committed by GitHub on Mar 09, 2021.

[perf] Further improve performance for FSDP.no_sync (#502)

parent aa9129a3

Showing 2 changed files with 61 additions and 14 deletions (+61, -14):

fairscale/nn/data_parallel/fully_sharded_data_parallel.py  (+4, -4)
tests/nn/data_parallel/test_fsdp_no_sync.py  (+57, -10)
fairscale/nn/data_parallel/fully_sharded_data_parallel.py

...
@@ -901,11 +901,11 @@ class FullyShardedDataParallel(nn.Module):
         if param.grad.requires_grad:
             raise RuntimeError("FullyShardedDataParallel only works with gradients that don't require grad")

-        if not self._is_root or self._require_backward_grad_sync:
+        if self._require_backward_grad_sync or self.reshard_after_forward:
             # Free full params. As a special case, we don't free the full params
-            # on the root instance when in a ``no_sync`` context (as indicated
-            # by ``self._require_backward_grad_sync``), since we will need the
-            # params again immediately for the next forward.
+            # when in a ``no_sync`` context (as inversely indicated by
+            # ``self._require_backward_grad_sync``), since the params will not
+            # get updated before the next forward.
             self._free_full_params([param])

         if self.mixed_precision:
...
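The condition change above means that, inside a ``no_sync`` context with ``reshard_after_forward`` disabled, the full parameters stay materialized between accumulation steps, so subsequent forwards skip the all-gather. Below is a minimal sketch of the gradient-accumulation pattern this optimizes; the helper function, loss, optimizer, and batches are hypothetical placeholders, not part of this commit:

from fairscale.nn.data_parallel import FullyShardedDataParallel as FSDP

def train_step_with_accumulation(model: FSDP, optimizer, batches):
    # Hypothetical helper: every batch except the last runs inside no_sync(),
    # where gradients accumulate locally and no reduce-scatter is issued.
    *accumulation_batches, final_batch = batches
    with model.no_sync():
        for batch in accumulation_batches:
            loss = model(*batch).sum()  # placeholder loss
            loss.backward()
    # The backward outside no_sync() reduce-scatters the accumulated gradients.
    loss = model(*final_batch).sum()  # placeholder loss
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()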
tests/nn/data_parallel/test_fsdp_no_sync.py

...
@@ -4,15 +4,17 @@
 # LICENSE file in the root directory of this source tree.

 import functools
+import itertools
 import unittest
 from unittest.mock import patch

+from parameterized import parameterized
 import torch

 from fairscale.nn.data_parallel import FullyShardedDataParallel
 from fairscale.utils.testing import DummyProcessGroup, objects_are_equal

-from .test_fsdp import DistributedTest, NestedWrappedModule, spawn_and_init
+from .test_fsdp import DistributedTest, NestedWrappedModule, rename_test, spawn_and_init


 class TestNoSync(DistributedTest):
...
@@ -94,21 +96,60 @@ class TestNoSync(DistributedTest):
         assert objects_are_equal(ref_grads, accumulated_grads, raise_exception=True)


+keys = ["reshard_after_forward", "mixed_precision"]
+COMM_CONFIG_OPTIONS = [[dict(zip(keys, config))] for config in itertools.product([True, False], repeat=len(keys))]
+
+
 class TestNoSyncCommunication(DistributedTest):
-    def test_communication(self):
-        config = {"mixed_precision": True}
+    @parameterized.expand(COMM_CONFIG_OPTIONS, name_func=rename_test)
+    def test_communication(self, config):
         fn = functools.partial(self._test_communication, config=config)
         spawn_and_init(fn)

+    @parameterized.expand(COMM_CONFIG_OPTIONS, name_func=rename_test)
+    def test_communication_nested(self, config):
+        fn = functools.partial(self._test_communication, config=config, nested_model=True)
+        spawn_and_init(fn)
+
     @classmethod
-    def _test_communication(self, rank, group, config):
+    def _test_communication(self, rank, group, config, nested_model=False):
         if group.size() == 1:
             return
-        model = self.get_wrapped_model(group, config=config)
+
+        # Turn off bucketing to accurately count number of reduce_scatters.
+        config["bucket_cap_mb"] = 0
+
+        if nested_model:
+            model = NestedWrappedModule(group, config)
+            model = FullyShardedDataParallel(model, group, **config).cuda()
+        else:
+            model = self.get_wrapped_model(group, config=config)
+
+        num_fsdp = 0
+        for child in model.modules():  # includes self
+            if isinstance(child, FullyShardedDataParallel) and len(child.params) > 0:
+                num_fsdp += 1
+
+        if config.get("reshard_after_forward", True):
+            # inside no_sync:
+            #   num_fsdp all-gathers in the forward
+            #   num_fsdp-1 all-gathers in the backward (except root)
+            # outside no_sync:
+            #   num_fsdp-1 all-gathers in the forward (except root)
+            #   num_fsdp-1 all-gathers in the backward (except root)
+            expected_all_gather1 = 2 * num_fsdp - 1
+            expected_all_gather2 = expected_all_gather1 + (2 * num_fsdp - 2)
+        else:
+            # inside no_sync:
+            #   num_fsdp all-gathers in the forward
+            # outside no_sync:
+            #   none
+            expected_all_gather1 = num_fsdp
+            expected_all_gather2 = num_fsdp
+
+        expected_reduce_scatter = num_fsdp
+
         batch = model.module.get_input(torch.device("cuda"))
         with patch("torch.distributed.all_gather") as mock_all_gather:
             with patch("torch.distributed.reduce_scatter") as mock_reduce_scatter:
                 with model.no_sync():
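For reference, the ``COMM_CONFIG_OPTIONS`` list built above expands to four single-argument parameter sets, one per combination of ``reshard_after_forward`` and ``mixed_precision``; ``parameterized.expand`` then generates one test case per entry. A small sketch of the expansion, independent of the test harness:

import itertools

keys = ["reshard_after_forward", "mixed_precision"]
COMM_CONFIG_OPTIONS = [[dict(zip(keys, config))] for config in itertools.product([True, False], repeat=len(keys))]
# Expands to:
# [[{"reshard_after_forward": True,  "mixed_precision": True}],
#  [{"reshard_after_forward": True,  "mixed_precision": False}],
#  [{"reshard_after_forward": False, "mixed_precision": True}],
#  [{"reshard_after_forward": False, "mixed_precision": False}]]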
...
@@ -116,15 +157,21 @@ class TestNoSyncCommunication(DistributedTest):
                     loss = model.module.get_loss(batch, output)
                     loss.backward()
-                assert mock_all_gather.call_count == 1
-                assert mock_reduce_scatter.call_count == 0
+                assert (
+                    mock_all_gather.call_count == expected_all_gather1
+                ), f"{mock_all_gather.call_count} != {expected_all_gather1}"
+                assert mock_reduce_scatter.call_count == 0, f"{mock_reduce_scatter.call_count} != 0"

                 output = model(*batch)
                 loss = model.module.get_loss(batch, output)
                 loss.backward()
-                assert mock_all_gather.call_count == 1
-                assert mock_reduce_scatter.call_count == 1
+                assert (
+                    mock_all_gather.call_count == expected_all_gather2
+                ), f"{mock_all_gather.call_count} != {expected_all_gather2}"
+                assert (
+                    mock_reduce_scatter.call_count == expected_reduce_scatter
+                ), f"{mock_reduce_scatter.call_count} != {expected_reduce_scatter}"


 if __name__ == "__main__":
...
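To make the counting logic in ``_test_communication`` concrete: for a hypothetical model with ``num_fsdp = 3`` FSDP instances holding parameters and ``reshard_after_forward=True``, the first (``no_sync``) iteration expects 2*3 - 1 = 5 all-gathers and the second iteration adds 2*3 - 2 = 4 more, for a cumulative 9; with ``reshard_after_forward=False`` both totals stay at 3, since the full params remain materialized after the first forward. A short sketch of the same arithmetic (the value 3 is illustrative, not the test's actual model size):

def expected_all_gathers(num_fsdp: int, reshard_after_forward: bool) -> tuple:
    # Mirrors the counting comments in the test: returns the expected cumulative
    # all-gather call counts after the first (no_sync) and second iterations.
    if reshard_after_forward:
        first = 2 * num_fsdp - 1
        second = first + (2 * num_fsdp - 2)
    else:
        first = num_fsdp
        second = num_fsdp
    return first, second

assert expected_all_gathers(3, True) == (5, 9)
assert expected_all_gathers(3, False) == (3, 3)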